module utils
function set_zerosum_gauge
set_zerosum_gauge(params: Dict[str, Tensor]) → Dict[str, Tensor]
Sets the zero-sum gauge on the coupling matrix.
Args:
params (Dict[str, torch.Tensor]): Parameters of the model.
Returns:
Dict[str, torch.Tensor]: Parameters with fixed gauge.
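The entry above does not say how the gauge is fixed, so the following is only a minimal sketch of the usual zero-sum transformation. It assumes the couplings are stored under a "coupling_matrix" key with shape (L, q, L, q); the function name zerosum_gauge_sketch, the key name, and the tensor layout are all illustrative assumptions rather than the library's confirmed behaviour.

```python
import torch


def zerosum_gauge_sketch(params):
    """Hypothetical stand-in: shift the couplings into the zero-sum gauge."""
    # Assumed layout: J[i, a, j, b] with shape (L, q, L, q).
    J = params["coupling_matrix"]
    J = (
        J
        - J.mean(dim=1, keepdim=True)       # remove the mean over state a
        - J.mean(dim=3, keepdim=True)       # remove the mean over state b
        + J.mean(dim=(1, 3), keepdim=True)  # add back the doubly subtracted mean
    )
    # Return a new dict so the caller's input is left untouched.
    return {**params, "coupling_matrix": J}


torch.manual_seed(0)
L, q = 4, 3
params = {
    "bias": torch.zeros(L, q),
    "coupling_matrix": torch.randn(L, q, L, q),
}
gauged = zerosum_gauge_sketch(params)
```

After the transformation, every (i, j) block of the coupling tensor sums to zero along each of its two state axes, which is the defining property of the zero-sum gauge.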
function init_parameters
init_parameters(fi: Tensor) → Dict[str, Tensor]
Initialize the parameters of the DCA model.
Args:
fi (torch.Tensor): Single-point frequencies of the data.
Returns:
Dict[str, torch.Tensor]: Parameters of the model.
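One common way to initialize a DCA model from single-point frequencies is to start from the profile (independent-site) model: biases equal to the log-frequencies and all couplings set to zero. The sketch below illustrates that idea only. The function name init_parameters_sketch, the "bias" and "coupling_matrix" keys, the (L, q, L, q) coupling shape, and the eps regularizer are assumptions made for illustration and are not taken from the entry above.

```python
import torch


def init_parameters_sketch(fi, eps=1e-10):
    """Hypothetical stand-in: profile-model initialization from frequencies."""
    L, q = fi.shape
    return {
        # eps (an assumption) keeps the logarithm finite for unobserved states.
        "bias": torch.log(fi + eps),
        # No couplings at initialization: the model starts as independent sites.
        "coupling_matrix": torch.zeros((L, q, L, q), dtype=fi.dtype, device=fi.device),
    }


fi = torch.tensor([[0.7, 0.2, 0.1], [0.25, 0.25, 0.5]])
params = init_parameters_sketch(fi)
```

With this choice, applying a softmax to each row of the bias recovers the input frequencies, so the initial model reproduces the single-site statistics of the data.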
function init_chains
init_chains(
num_chains: int,
L: int,
q: int,
device: device,
dtype: dtype = torch.float32,
fi: Tensor | None = None
) → Tensor
Initialize the chains of the DCA model. If 'fi' is provided, the chains are sampled from the profile model; otherwise, they are sampled uniformly at random.
Args:
num_chains (int): Number of parallel chains.
L (int): Length of the MSA.
q (int): Number of values that each residue can assume.
device (torch.device): Device on which to store the chains.
dtype (torch.dtype, optional): Data type of the chains. Defaults to torch.float32.
fi (torch.Tensor | None, optional): Single-point frequencies. Defaults to None.
Returns:
torch.Tensor: Initialized parallel chains in one-hot encoding format.
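The two sampling modes described for init_chains can be sketched as follows. This is an illustrative reimplementation under stated assumptions, not the library's code: the name init_chains_sketch is hypothetical, 'fi' is assumed to have shape (L, q), and the output is assumed to have shape (num_chains, L, q).

```python
import torch


def init_chains_sketch(num_chains, L, q, device, dtype=torch.float32, fi=None):
    """Hypothetical stand-in: one-hot chains, uniform or drawn from a profile."""
    if fi is None:
        # Uniform case: every site takes one of the q states with equal probability.
        states = torch.randint(0, q, (num_chains, L), device=device)
    else:
        # Profile case: each row of fi is a categorical distribution over q states.
        # torch.multinomial returns shape (L, num_chains), hence the transpose.
        states = torch.multinomial(fi, num_chains, replacement=True).T.to(device)
    return torch.nn.functional.one_hot(states, num_classes=q).to(dtype)


cpu = torch.device("cpu")
uniform_chains = init_chains_sketch(5, 4, 3, cpu)

# A degenerate profile makes the sampled chains deterministic, which is easy to check.
fi = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
profile_chains = init_chains_sketch(5, 4, 3, cpu, fi=fi)
```

In both cases each site of each chain holds exactly one nonzero entry along the last axis.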
function get_mask_save
get_mask_save(L: int, q: int, device: device) → Tensor
Returns the mask to save the upper-triangular part of the coupling matrix.
Args:
L (int): Length of the MSA.
q (int): Number of values that each residue can assume.
device (torch.device): Device on which to store the mask.
Returns:
torch.Tensor: Mask.
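As an illustration of what such a mask might look like, the sketch below builds a boolean tensor that is True only for site pairs with i < j. The shape (L, q, L, q), the exclusion of the diagonal (i == j), and the name upper_triangular_mask_sketch are assumptions; the entry above specifies only that the mask selects the upper-triangular part.

```python
import torch


def upper_triangular_mask_sketch(L, q, device):
    """Hypothetical stand-in: True where the site index i is strictly below j."""
    idx = torch.arange(L, device=device)
    upper = idx[:, None] < idx[None, :]  # shape (L, L)
    # Broadcast over both state axes, then materialize an independent copy.
    return upper[:, None, :, None].expand(L, q, L, q).clone()


mask = upper_triangular_mask_sketch(3, 2, torch.device("cpu"))
```

Indexing a coupling tensor J of the same shape with J[mask] then yields L(L-1)/2 * q^2 entries, one per unordered pair of sites and pair of states, which avoids writing each symmetric coupling twice.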
function resample_sequences
resample_sequences(data: Tensor, weights: Tensor, nextract: int) → Tensor
Extracts nextract sequences from data with replacement according to the weights.
Args:
data (torch.Tensor): Data array.
weights (torch.Tensor): Weights of the sequences.
nextract (int): Number of sequences to be extracted.
Returns:
torch.Tensor: Extracted sequences.
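Weighted sampling with replacement can be sketched with torch.multinomial, as below. The name resample_sketch is hypothetical, and the sketch assumes that the first axis of 'data' indexes sequences and that 'weights' holds one non-negative weight per sequence; whether the library proceeds this way is not stated above.

```python
import torch


def resample_sketch(data, weights, nextract):
    """Hypothetical stand-in: draw rows of data with probability proportional to weights."""
    # torch.multinomial normalizes the weights internally.
    idx = torch.multinomial(weights.reshape(-1), nextract, replacement=True)
    return data[idx]


data = torch.arange(12).reshape(4, 3)
# All weight on the third sequence, so every draw must return that row.
weights = torch.tensor([0.0, 0.0, 1.0, 0.0])
picked = resample_sketch(data, weights, nextract=5)
```

Because sampling is done with replacement, nextract may exceed the number of sequences in data.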
function get_device
get_device(device: str, message: bool = True) → device
Returns the device on which to store the tensors.
Args:
device (str): Device to be used.
message (bool, optional): Whether to print the selected device. Defaults to True.
Returns:
torch.device: Device.
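A helper of this kind typically converts a string into a torch.device and handles the case where the requested accelerator is missing. The sketch below shows one plausible policy, falling back to the CPU when CUDA is requested but unavailable. That fallback, the printed message, and the name pick_device_sketch are assumptions, since the entry above does not describe the behaviour.

```python
import torch


def pick_device_sketch(device, message=True):
    """Hypothetical stand-in: resolve a device string, falling back to the CPU."""
    if device.startswith("cuda") and torch.cuda.is_available():
        chosen = torch.device(device)
    else:
        # Assumed policy: anything that cannot be served by CUDA runs on the CPU.
        chosen = torch.device("cpu")
    if message:
        print(f"Running on {chosen}")
    return chosen


dev = pick_device_sketch("cpu", message=False)
maybe_gpu = pick_device_sketch("cuda", message=False)
```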
function get_dtype
get_dtype(dtype: str) → dtype
Returns the data type of the tensors.
Args:
dtype (str): Data type.
Returns:
torch.dtype: Data type.
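The string-to-dtype conversion can be sketched as a dictionary lookup. Which strings the library accepts is not stated above, so the "float32" and "float64" keys, the ValueError on unknown input, and the name parse_dtype_sketch are assumptions made for illustration.

```python
import torch

# Assumed set of accepted names; the real function may support others.
_DTYPES = {"float32": torch.float32, "float64": torch.float64}


def parse_dtype_sketch(name):
    """Hypothetical stand-in: map a dtype name to the matching torch.dtype."""
    try:
        return _DTYPES[name]
    except KeyError:
        raise ValueError(f"Unsupported dtype {name!r}; expected one of {sorted(_DTYPES)}") from None


single = parse_dtype_sketch("float32")
double = parse_dtype_sketch("float64")
```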
This file was automatically generated via lazydocs.