module adabmDCA.utils
function init_parameters
init_parameters(fi: Tensor) → Dict[str, Tensor]
Initialize the parameters of the DCA model. The bias terms are initialized from the single-point frequencies 'fi', while the coupling matrix is initialized to zero.
Args:
fi(torch.Tensor): Single-point frequencies of the data.
Returns:
Dict[str, torch.Tensor]:
- "bias" (torch.Tensor): Bias terms.
- "coupling_matrix" (torch.Tensor): Coupling matrix.
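A minimal sketch of the initialization described above, assuming that the biases are set to the logarithm of the single-point frequencies (the profile, or independent-site, model) and that the couplings are stored as a dense (L, q, L, q) tensor. Both assumptions, the `eps` clamp and the name `init_parameters_sketch` are illustrative and are not taken from adabmDCA.

```python
from typing import Dict

import torch


def init_parameters_sketch(fi: torch.Tensor, eps: float = 1e-8) -> Dict[str, torch.Tensor]:
    """Hypothetical re-implementation: profile-model biases, zero couplings."""
    L, q = fi.shape
    # Assumption: bias_i(a) = log f_i(a); eps (our addition) avoids log(0).
    bias = torch.log(fi.clamp(min=eps))
    # Assumption: couplings stored as a dense (L, q, L, q) tensor of zeros.
    coupling_matrix = torch.zeros((L, q, L, q), dtype=fi.dtype, device=fi.device)
    return {"bias": bias, "coupling_matrix": coupling_matrix}


fi = torch.tensor([[0.7, 0.2, 0.1],
                   [0.1, 0.1, 0.8]])
params = init_parameters_sketch(fi)
# With zero couplings the sites are independent, and a per-site softmax
# of the biases recovers fi (each row of fi sums to one).
print(torch.softmax(params["bias"], dim=-1))
```

Starting from the profile model means that, before any training step, samples drawn from the model already reproduce the single-point frequencies of the data.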
function init_chains
init_chains(
num_chains: int,
L: int,
q: int,
device: device,
dtype: dtype = torch.float32,
fi: Optional[Tensor] = None
) → Tensor
Initialize the Markov chains of the DCA model. If 'fi' is provided, the chains are sampled from the profile (independent-site) model defined by 'fi'; otherwise, they are sampled uniformly at random.
Args:
num_chains(int): Number of parallel chains.
L(int): Length of the sequences in the MSA.
q(int): Number of values that each residue can assume.
device(torch.device): Device on which to store the chains.
dtype(torch.dtype, optional): Data type of the chains. Defaults to torch.float32.
fi(Optional[torch.Tensor], optional): Single-point frequencies. Defaults to None.
Returns:
torch.Tensor: Initialized Markov chains in one-hot encoding format, shape (num_chains, L, q).
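The two sampling modes can be sketched with plain PyTorch: uniform sampling via `torch.randint`, and profile-model sampling via `torch.multinomial`, which draws each site independently from the corresponding row of fi. This is not the adabmDCA code; the name `init_chains_sketch` is hypothetical, and fi is assumed to have shape (L, q) with non-negative entries in each row.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def init_chains_sketch(
    num_chains: int,
    L: int,
    q: int,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
    fi: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Hypothetical sketch: one-hot chains of shape (num_chains, L, q)."""
    if fi is None:
        # Uniform: every state at every site is equally likely.
        states = torch.randint(0, q, (num_chains, L), device=device)
    else:
        # Profile model (assumed fi shape: (L, q)): site i is drawn from row fi[i].
        # torch.multinomial on an (L, q) matrix returns (L, num_chains) indices.
        states = torch.multinomial(fi.to(device), num_chains, replacement=True).T
    # Convert integer states to one-hot vectors and cast to the requested dtype.
    return F.one_hot(states, num_classes=q).to(dtype)


chains = init_chains_sketch(num_chains=4, L=5, q=21, device=torch.device("cpu"))
print(chains.shape)  # torch.Size([4, 5, 21])
```

Each site of each chain contains exactly one 1, so `chains.sum(dim=-1)` is all ones regardless of the sampling mode.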
function get_mask_save
get_mask_save(L: int, q: int, device: device) → Tensor
Returns the mask to save the upper-triangular part of the coupling matrix.
Args:
L(int): Length of the sequences in the MSA.
q(int): Number of values that each residue can assume.
device(torch.device): Device on which to store the mask.
Returns:
torch.Tensor: Mask.
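A sketch of what such a mask can look like, assuming the coupling matrix has shape (L, q, L, q) and that "upper-triangular" refers to site pairs with i < j. Because Potts couplings satisfy J_ij(a, b) = J_ji(b, a), that half holds every independent coupling; the diagonal i = j is excluded as well under this assumption. `get_mask_save_sketch` is a hypothetical name, not the library function.

```python
import torch


def get_mask_save_sketch(L: int, q: int, device: torch.device) -> torch.Tensor:
    """Hypothetical sketch: True at (i, a, j, b) whenever site i < site j."""
    idx = torch.arange(L, device=device)
    # Strictly upper-triangular site mask: drops i == j and the mirrored i > j half.
    site_mask = idx[:, None] < idx[None, :]
    # Assumption: couplings are laid out as (L, q, L, q); broadcast over both
    # state axes, then copy so the result is an ordinary writable tensor.
    return site_mask[:, None, :, None].expand(L, q, L, q).contiguous()


L, q = 4, 3
mask = get_mask_save_sketch(L, q, torch.device("cpu"))
coupling_matrix = torch.randn(L, q, L, q)
unique_couplings = coupling_matrix[mask]  # one entry per (i < j, a, b)
print(unique_couplings.numel())  # 54 = 6 site pairs * 3 * 3
```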
function resample_sequences
resample_sequences(data: Tensor, weights: Tensor, nextract: int) → Tensor
Extracts 'nextract' sequences from 'data', sampling with replacement according to the weights.
Args:
data(torch.Tensor): Data array.
weights(torch.Tensor): Weights of the sequences.
nextract(int): Number of sequences to be extracted.
Returns:
torch.Tensor: Extracted sequences.
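Weighted resampling with replacement maps directly onto `torch.multinomial`. The sketch below assumes that `weights` holds one non-negative value per sequence (any shape that flattens to the number of rows of `data`); `resample_sequences_sketch` is a hypothetical name, not the adabmDCA implementation.

```python
import torch


def resample_sequences_sketch(
    data: torch.Tensor, weights: torch.Tensor, nextract: int
) -> torch.Tensor:
    """Hypothetical sketch: draw nextract rows of data with replacement."""
    # Assumption: one weight per sequence. torch.multinomial normalizes
    # non-negative weights internally, so they need not sum to one.
    idx = torch.multinomial(weights.reshape(-1).float(), nextract, replacement=True)
    # Index along the first axis, so any trailing shape (e.g. L, q) is kept.
    return data[idx]


data = torch.arange(12).reshape(4, 3)
weights = torch.tensor([0.0, 0.0, 1.0, 0.0])
sample = resample_sequences_sketch(data, weights, 5)
print(sample)  # every row is data[2] = [6, 7, 8]
```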
function get_device
get_device(device: str, message: bool = True) → device
Returns the device on which to store the tensors.
Args:
device(str): Device to be used. Possible values are 'cpu', 'cuda' and 'mps'.
message(bool, optional): Whether to print the device. Defaults to True.
Returns:
torch.device: Device.
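A sketch of the device selection, assuming that an unavailable accelerator falls back to the CPU and that any other string is rejected; the documentation above does not specify either behavior, and `get_device_sketch` is a hypothetical name.

```python
import torch


def get_device_sketch(device: str, message: bool = True) -> torch.device:
    """Hypothetical sketch; assumes a CPU fallback when the accelerator is missing."""
    if device.startswith("cuda"):
        dev = torch.device(device) if torch.cuda.is_available() else torch.device("cpu")
    elif device == "mps":
        # torch.backends.mps only exists in recent PyTorch releases.
        mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        dev = torch.device("mps") if mps_ok else torch.device("cpu")
    elif device == "cpu":
        dev = torch.device("cpu")
    else:
        # Assumption: unknown names are rejected rather than silently mapped.
        raise ValueError(f"Unsupported device {device!r}; use 'cpu', 'cuda' or 'mps'.")
    if message:
        print(f"Running on {dev}")
    return dev


dev = get_device_sketch("cuda")  # prints "Running on cuda" or "Running on cpu"
```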
function get_dtype
get_dtype(dtype: str) → dtype
Returns the data type of the tensors.
Args:
dtype(str): Data type. Possible values are 'float32' and 'float64'.
Returns:
torch.dtype: Data type.
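Mapping the two accepted strings to PyTorch dtypes is a simple lookup. Raising a `ValueError` for any other value is an assumption of this sketch, and `get_dtype_sketch` is a hypothetical name.

```python
import torch

# The two values listed in the documentation above.
_DTYPES = {"float32": torch.float32, "float64": torch.float64}


def get_dtype_sketch(dtype: str) -> torch.dtype:
    """Hypothetical sketch: translate a dtype name into a torch.dtype."""
    try:
        return _DTYPES[dtype]
    except KeyError:
        # Assumption: unsupported names raise instead of falling back.
        raise ValueError(f"Unsupported dtype {dtype!r}; use 'float32' or 'float64'.") from None


x = torch.zeros(3, dtype=get_dtype_sketch("float64"))
print(x.dtype)  # torch.float64
```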
This file was automatically generated via lazydocs.