
module utils


function set_zerosum_gauge

set_zerosum_gauge(params: Dict[str, Tensor]) → Dict[str, Tensor]

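Sets the zero-sum gauge on the coupling matrix: each coupling block between residues i and j is shifted so that it sums to zero over the states of i and over the states of j.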

Args:

  • params (Dict[str, torch.Tensor]): Parameters of the model.

Returns:

  • Dict[str, torch.Tensor]: Parameters with the coupling matrix in the zero-sum gauge.
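The zero-sum condition can be illustrated with a minimal sketch, assuming the couplings are stored under a hypothetical key "coupling_matrix" as a tensor of shape (L, q, L, q) indexed as J[i, a, j, b]. Each block is double-centered (row and column means subtracted, overall mean added back), which makes it sum to zero over a and over b. This is an illustration of the gauge condition, not the library's own code; the function name is hypothetical.

```python
from typing import Dict

import torch


def zerosum_gauge_sketch(params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Assumption: couplings live under the hypothetical key "coupling_matrix"
    # with shape (L, q, L, q), indexed as J[i, a, j, b].
    J = params["coupling_matrix"]
    mean_a = J.mean(dim=1, keepdim=True)          # average over the state a of residue i
    mean_b = J.mean(dim=3, keepdim=True)          # average over the state b of residue j
    mean_ab = J.mean(dim=(1, 3), keepdim=True)    # average over both states
    # Double centering: J'_ij(a, b) = J_ij(a, b) - <J_ij(., b)> - <J_ij(a, .)> + <J_ij(., .)>
    out = dict(params)  # shallow copy, so the input dictionary is left untouched
    out["coupling_matrix"] = J - mean_a - mean_b + mean_ab
    return out


L, q = 4, 3
params = {"coupling_matrix": torch.randn(L, q, L, q)}
gauged = zerosum_gauge_sketch(params)
```

After the transformation, summing any block over either state index gives zero (up to floating-point error).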

function init_parameters

init_parameters(fi: Tensor) → Dict[str, Tensor]

Initialize the parameters of the DCA model.

Args:

  • fi (torch.Tensor): Single-point frequencies of the data.

Returns:

  • Dict[str, torch.Tensor]: Parameters of the model.
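A common way to initialize DCA parameters is the independent-site (profile) model: fields equal to the logarithm of the single-point frequencies and couplings set to zero. The sketch below assumes that convention, an fi of shape (L, q), and hypothetical keys "bias" and "coupling_matrix"; the clamp with eps that guards against log(0) is an added assumption, not documented behavior.

```python
from typing import Dict

import torch


def init_parameters_sketch(fi: torch.Tensor, eps: float = 1e-8) -> Dict[str, torch.Tensor]:
    # Assumption: fi has shape (L, q) and each row sums to 1.
    L, q = fi.shape
    return {
        # Profile-model fields; eps (hypothetical) avoids log(0) for unobserved states.
        "bias": torch.log(fi.clamp_min(eps)),
        # No interactions between residues at initialization.
        "coupling_matrix": torch.zeros(L, q, L, q, dtype=fi.dtype, device=fi.device),
    }


fi = torch.tensor([[0.5, 0.5, 0.0], [0.2, 0.3, 0.5]])
params = init_parameters_sketch(fi)
```

With zero couplings, the marginal of each site is the softmax of its fields, so this initialization reproduces fi up to eps.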

function init_chains

init_chains(
    num_chains: int,
    L: int,
    q: int,
    device: device,
    dtype: dtype = torch.float32,
    fi: Tensor | None = None
) → Tensor

Initialize the chains of the DCA model. If 'fi' is provided, the chains are sampled from the profile model; otherwise, they are sampled uniformly at random.

Args:

  • num_chains (int): Number of parallel chains.
  • L (int): Length of the sequences in the MSA.
  • q (int): Number of values that each residue can assume.
  • device (torch.device): Device on which to store the chains.
  • dtype (torch.dtype, optional): Data type of the chains. Defaults to torch.float32.
  • fi (torch.Tensor | None, optional): Single-point frequencies. Defaults to None.

Returns:

  • torch.Tensor: Initialized parallel chains in one-hot encoding format.
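The two initialization modes can be sketched as follows, assuming fi has shape (L, q) with one probability distribution per row and that the one-hot output has shape (num_chains, L, q). The function name is hypothetical and this is not the library's implementation: uniform chains draw each state with torch.randint, and profile chains draw each residue independently from its row of fi with torch.multinomial.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def init_chains_sketch(
    num_chains: int,
    L: int,
    q: int,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
    fi: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if fi is None:
        # Uniform: every residue takes one of the q states with equal probability.
        states = torch.randint(0, q, (num_chains, L), device=device)
    else:
        # Profile model: one draw per chain from each row of fi.
        # multinomial returns shape (L, num_chains); transpose to (num_chains, L).
        states = torch.multinomial(fi.to(device), num_samples=num_chains, replacement=True).T
    # One-hot encode to (num_chains, L, q) and cast to the requested dtype.
    return F.one_hot(states, num_classes=q).to(dtype)


device = torch.device("cpu")
chains = init_chains_sketch(5, 4, 3, device)
fi = torch.tensor([[1.0, 0.0, 0.0]] * 4)
profile_chains = init_chains_sketch(5, 4, 3, device, fi=fi)
```

Because the profile in this usage example puts all mass on state 0, every profile chain is state 0 at every position.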

function get_mask_save

get_mask_save(L: int, q: int, device: device) → Tensor

Returns the mask used to save only the upper-triangular part of the coupling matrix.

Args:

  • L (int): Length of the sequences in the MSA.
  • q (int): Number of values that each residue can assume.
  • device (torch.device): Device on which to store the mask.

Returns:

  • torch.Tensor: Mask.
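One way to build such a mask is sketched below, assuming a coupling tensor of shape (L, q, L, q) and that only residue pairs with i < j are kept (whether the diagonal blocks i = j are excluded is an assumption here). The function name is hypothetical. A boolean (L, L) upper-triangular mask is broadcast over both state dimensions, so indexing the couplings with it returns only the selected entries.

```python
import torch


def get_mask_save_sketch(L: int, q: int, device: torch.device) -> torch.Tensor:
    # Strictly upper-triangular residue-pair mask (i < j), built in float then cast to bool.
    upper = torch.triu(torch.ones(L, L, device=device), diagonal=1).bool()
    # Broadcast over the state dimensions a and b, giving shape (L, q, L, q).
    return upper[:, None, :, None].expand(L, q, L, q).contiguous()


L, q = 3, 2
mask = get_mask_save_sketch(L, q, torch.device("cpu"))
J = torch.randn(L, q, L, q)
saved = J[mask]  # 1-D tensor with L * (L - 1) / 2 * q * q entries
```

Saving only one triangle halves the storage, since a symmetric coupling matrix satisfies J[i, a, j, b] == J[j, b, i, a].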

function resample_sequences

resample_sequences(data: Tensor, weights: Tensor, nextract: int) → Tensor

Samples nextract sequences from data with replacement, drawing each sequence with probability proportional to its weight.

Args:

  • data (torch.Tensor): Data array.
  • weights (torch.Tensor): Weights of the sequences.
  • nextract (int): Number of sequences to be extracted.

Returns:

  • torch.Tensor: Extracted sequences.
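Weighted sampling with replacement maps directly onto torch.multinomial, which draws indices with probability proportional to non-negative weights. The sketch below assumes sequences are stacked along the first dimension of data and that weights holds one non-negative float per sequence; the function name is hypothetical.

```python
import torch


def resample_sequences_sketch(data: torch.Tensor, weights: torch.Tensor, nextract: int) -> torch.Tensor:
    # Draw nextract row indices with replacement; P(index k) is proportional to weights[k].
    indices = torch.multinomial(weights.reshape(-1), nextract, replacement=True)
    # Gather the sampled sequences along the first (sequence) dimension.
    return data[indices]


data = torch.arange(12).reshape(4, 3)  # four toy "sequences" of length 3
weights = torch.tensor([0.0, 1.0, 0.0, 3.0])
sample = resample_sequences_sketch(data, weights, nextract=10)
```

Sequences with zero weight are never drawn, so here every sampled row is either the second or the fourth sequence.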

function get_device

get_device(device: str, message: bool = True) → device

Returns the device on which to store the tensors.

Args:

  • device (str): Device to be used.
  • message (bool, optional): Whether to print the selected device. Defaults to True.

Returns:

  • torch.device: Device.
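A minimal sketch of the device-selection logic, assuming that a CUDA request falls back to the CPU when no GPU is available (this fallback policy and the printed message format are assumptions; the function name is hypothetical).

```python
import torch


def get_device_sketch(device: str, message: bool = True) -> torch.device:
    # Assumption: use CUDA only if it was requested and is actually available.
    if device.startswith("cuda") and torch.cuda.is_available():
        selected = torch.device(device)
    else:
        selected = torch.device("cpu")
    if message:
        print(f"Running on {selected}")
    return selected


dev = get_device_sketch("cpu", message=False)
```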

function get_dtype

get_dtype(dtype: str) → dtype

Returns the torch data type corresponding to the given string.

Args:

  • dtype (str): Data type.

Returns:

  • torch.dtype: Data type.
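The string-to-dtype conversion can be sketched as a lookup table. The accepted names ("float32", "float64") and the ValueError for anything else are assumptions for illustration; the function name is hypothetical.

```python
import torch

# Hypothetical table of supported names.
_DTYPES = {"float32": torch.float32, "float64": torch.float64}


def get_dtype_sketch(dtype: str) -> torch.dtype:
    # Assumption: only the names in _DTYPES are accepted; anything else is rejected.
    try:
        return _DTYPES[dtype]
    except KeyError:
        raise ValueError(f"Unsupported dtype {dtype!r}; expected one of {sorted(_DTYPES)}") from None
```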

This file was automatically generated via lazydocs.