module adabmDCA.checkpoint


class Checkpoint

Helper class to save the model's parameters and chains at regular intervals during training and to log the progress of the training.

method __init__

__init__(
    file_paths: dict,
    tokens: str,
    args: dict,
    params: Optional[Dict[str, Tensor]] = None,
    chains: Optional[Tensor] = None,
    use_wandb: bool = False
)

Initializes the Checkpoint class.

Args:

  • file_paths (dict): Dictionary containing the paths of the files to be saved.
  • tokens (str): Alphabet to be used for encoding the sequences.
  • args (dict): Dictionary containing the arguments of the training.
  • params (Optional[Dict[str, torch.Tensor]], optional): Parameters of the model. Defaults to None.
  • chains (Optional[torch.Tensor], optional): Chains to be saved together with the model's parameters. Defaults to None.
  • use_wandb (bool, optional): Whether to use Weights & Biases for logging. Defaults to False.

method check

check(updates: int) → bool

Checks if a checkpoint has been reached.

Args:

  • updates (int): Number of gradient updates performed.

Returns:

  • bool: Whether a checkpoint has been reached.
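The reference does not state how the checkpoint schedule is defined. As a minimal sketch, assuming checkpoints fall at a fixed interval of gradient updates, the behaviour of `check` could look like the following. `IntervalSchedule` and its `interval` argument are hypothetical names used only for illustration; they are not part of adabmDCA.

```python
class IntervalSchedule:
    """Hypothetical stand-in, not the adabmDCA implementation."""

    def __init__(self, interval: int) -> None:
        if interval <= 0:
            raise ValueError("interval must be positive")
        self.interval = interval

    def check(self, updates: int) -> bool:
        # Assumption: a checkpoint is reached every `interval` gradient updates.
        return updates > 0 and updates % self.interval == 0


schedule = IntervalSchedule(interval=50)
reached = [u for u in range(1, 201) if schedule.check(u)]
print(reached)  # [50, 100, 150, 200]
```

Other schedules (for example, logarithmically spaced checkpoints) would keep the same `check(updates) → bool` contract and change only the condition.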

method log

log(record: Dict[str, Any]) → None

Adds key-value pairs to the log dictionary.

Args:

  • record (Dict[str, Any]): Key-value pairs to be added to the log dictionary.
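A minimal sketch of the accumulation that `log` describes, assuming the log dictionary is a plain `dict` and that a repeated key overwrites the earlier value. `RecordLog`, its `logs` attribute, and the keys `"Epochs"` and `"Pearson"` are hypothetical and chosen only for illustration.

```python
from typing import Any, Dict


class RecordLog:
    """Hypothetical stand-in, not the adabmDCA implementation."""

    def __init__(self) -> None:
        self.logs: Dict[str, Any] = {}

    def log(self, record: Dict[str, Any]) -> None:
        # Merge the new key-value pairs; an existing key is overwritten.
        self.logs.update(record)


logger = RecordLog()
logger.log({"Epochs": 10, "Pearson": 0.42})
logger.log({"Pearson": 0.57})
print(logger.logs)  # {'Epochs': 10, 'Pearson': 0.57}
```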

method save

save(
    params: Dict[str, Tensor],
    mask: Tensor,
    chains: Tensor,
    log_weights: Tensor
) → None

Saves the chains and the parameters of the model.

Args:

  • params (Dict[str, torch.Tensor]): Parameters of the model.
  • mask (torch.Tensor): Mask of the model's coupling matrix, representing the interaction graph.
  • chains (torch.Tensor): Chains to be saved.
  • log_weights (torch.Tensor): Logarithm of the chain weights, used for annealed importance sampling (AIS).
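The three methods are presumably meant to be called together from a training loop. The sketch below shows one plausible call order using a stub that only mirrors the documented method names and signatures. `StubCheckpoint`, `train`, the fixed-interval schedule, the `"updates"` log key, and the parameter keys are all assumptions for illustration, and plain Python lists stand in for tensors.

```python
from typing import Any, Dict, List


class StubCheckpoint:
    """Hypothetical stub exposing the documented method names only."""

    def __init__(self, interval: int) -> None:
        self.interval = interval
        self.logs: Dict[str, Any] = {}
        self.saved_at: List[int] = []

    def check(self, updates: int) -> bool:
        # Assumption: fixed-interval schedule.
        return updates % self.interval == 0

    def log(self, record: Dict[str, Any]) -> None:
        self.logs.update(record)

    def save(self, params, mask, chains, log_weights) -> None:
        # A real implementation would write to disk; the stub only
        # records the update count at which it was called.
        self.saved_at.append(self.logs.get("updates", -1))


def train(checkpoint: StubCheckpoint, n_updates: int) -> None:
    # Placeholder state; names and shapes are illustrative.
    params = {"bias": [0.0], "coupling_matrix": [[0.0]]}
    mask = [[1]]
    chains = [[0]]
    log_weights = [0.0]
    for updates in range(1, n_updates + 1):
        params["bias"][0] += 0.1  # placeholder for a gradient step
        if checkpoint.check(updates):
            checkpoint.log({"updates": updates})
            checkpoint.save(params, mask, chains, log_weights)


ckpt = StubCheckpoint(interval=3)
train(ckpt, n_updates=10)
print(ckpt.saved_at)  # [3, 6, 9]
```

Gating both `log` and `save` behind `check` keeps disk writes and logging off the hot path between checkpoints.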

This file was automatically generated via lazydocs.