
module adabmDCA.checkpoint


class Checkpoint

Helper class to save the model's parameters and chains at regular intervals during training and to log the progress of the training.

method __init__

__init__(file_paths: dict, tokens: str, args: dict, use_wandb: bool = False)

Initializes the Checkpoint class.

Args:

  • file_paths (dict): Dictionary containing the paths of the files to be saved.
  • tokens (str): Alphabet to be used for encoding the sequences.
  • args (dict): Dictionary containing the arguments of the training.
  • use_wandb (bool, optional): Whether to use Weights & Biases for logging. Defaults to False.
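The tokens argument is the alphabet used to encode sequences. The sketch below assumes that each symbol is encoded as its position in the alphabet string, and shows that mapping in both directions. The names encode, decode, and PROTEIN_TOKENS, and the choice of alphabet, are hypothetical and used only for illustration. They are not part of adabmDCA's API.

```python
from typing import List

# Example alphabet (an assumption for illustration): a gap symbol followed by
# the 20 standard amino acids. Any string of distinct characters works the same way.
PROTEIN_TOKENS = "-ACDEFGHIKLMNPQRSTVWY"


def encode(sequence: str, tokens: str) -> List[int]:
    """Map each character of `sequence` to its index in the alphabet `tokens`."""
    lookup = {char: idx for idx, char in enumerate(tokens)}
    try:
        return [lookup[char] for char in sequence]
    except KeyError as err:
        raise ValueError(f"symbol {err.args[0]!r} is not in the alphabet") from err


def decode(indices: List[int], tokens: str) -> str:
    """Inverse of `encode`: turn integer indices back into a string."""
    return "".join(tokens[i] for i in indices)


# Usage: round-trip a short gapped sequence through the alphabet.
encoded = encode("-ACD", PROTEIN_TOKENS)
restored = decode(encoded, PROTEIN_TOKENS)
```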

method check

check(updates: int) → bool

Checks if a checkpoint has been reached.

Args:

  • updates (int): Number of gradient updates performed.

Returns:

  • bool: Whether a checkpoint has been reached.
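The docstring does not say how checkpoints are spaced. The sketch below assumes they fall at a fixed interval of gradient updates, in which case check reduces to a modulo test. The class IntervalCheckpoint and its interval parameter are hypothetical stand-ins, not the library's implementation.

```python
class IntervalCheckpoint:
    """Hypothetical stand-in: a checkpoint is reached every `interval` updates."""

    def __init__(self, interval: int) -> None:
        if interval <= 0:
            raise ValueError("interval must be a positive integer")
        self.interval = interval

    def check(self, updates: int) -> bool:
        # Reached when `updates` is a positive multiple of the interval.
        return updates > 0 and updates % self.interval == 0


# Usage: in a training loop, save only when a checkpoint is reached.
checkpoint = IntervalCheckpoint(interval=50)
saved_at = [u for u in range(0, 201) if checkpoint.check(u)]
```

A fixed interval keeps the logic stateless. A schedule that spaces checkpoints unevenly would need check to track the last saved update instead.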

method log

log(record: Dict[str, Any]) → None

Adds the key-value pairs in record to the log dictionary.

Args:

  • record (Dict[str, Any]): Key-value pairs to be added to the log dictionary.
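The following minimal sketch illustrates the behaviour described for log. It assumes that the log dictionary is a plain dict and that a later record overwrites an earlier value stored under the same key. LogSketch and the keys in the example ("Epochs", "Pearson") are hypothetical and chosen only for illustration.

```python
from typing import Any, Dict


class LogSketch:
    """Hypothetical stand-in for the log-accumulation behaviour described above."""

    def __init__(self) -> None:
        self.logs: Dict[str, Any] = {}

    def log(self, record: Dict[str, Any]) -> None:
        # Insert new keys and overwrite existing ones with the latest values.
        self.logs.update(record)


# Usage: accumulate metrics over successive calls.
logger = LogSketch()
logger.log({"Epochs": 100, "Pearson": 0.91})
logger.log({"Pearson": 0.93})
```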

method save

save(
    params: Dict[str, Tensor],
    mask: Tensor,
    chains: Tensor,
    log_weights: Tensor
) → None

Saves the chains and the parameters of the model.

Args:

  • params (Dict[str, torch.Tensor]): Parameters of the model.
  • mask (torch.Tensor): Mask of the model's coupling matrix representing the interaction graph.
  • chains (torch.Tensor): Current state of the Markov chains, i.e. the sampled sequences.
  • log_weights (torch.Tensor): Logarithm of the chain weights, used for annealed importance sampling (AIS).
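The docstring states only that log_weights are used for AIS. One general reason to keep importance weights in log space is that raw weights can easily overflow or underflow. The sketch below shows how such log-weights are typically consumed with the log-sum-exp trick. It is a generic illustration of AIS bookkeeping, not adabmDCA's code, and the function names are hypothetical. In standard AIS, log_mean_exp of the final log-weights estimates the log of the ratio between the target and initial partition functions.

```python
import math
from typing import List


def log_mean_exp(log_weights: List[float]) -> float:
    """Numerically stable log((1/n) * sum_i exp(log_weights[i]))."""
    if not log_weights:
        raise ValueError("log_weights must be non-empty")
    # Shift by the maximum so that exp() never overflows.
    m = max(log_weights)
    mean_shifted = sum(math.exp(lw - m) for lw in log_weights) / len(log_weights)
    return m + math.log(mean_shifted)


def normalize_log_weights(log_weights: List[float]) -> List[float]:
    """Return normalized weights w_i / sum_j w_j computed from log w_i."""
    m = max(log_weights)
    shifted = [math.exp(lw - m) for lw in log_weights]
    total = sum(shifted)
    return [s / total for s in shifted]


# Usage: exp(1000.0) would overflow a float, but the shifted form does not.
log_ratio = log_mean_exp([1000.0, 1000.0])
probs = normalize_log_weights([1000.0, 1000.0])
```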

This file was automatically generated via lazydocs.