module adabmDCA.checkpoint
class Checkpoint
Helper class that saves the model's parameters and chains at regular intervals during training and logs the training progress.
method __init__
__init__(file_paths: dict, tokens: str, args: dict, use_wandb: bool = False)
Initializes the Checkpoint class.
Args:
file_paths(dict): Dictionary containing the paths of the files to be saved.
tokens(str): Alphabet to be used for encoding the sequences.
args(dict): Dictionary containing the arguments of the training.
use_wandb(bool, optional): Whether to use Weights & Biases for logging. Defaults to False.
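The constructor's usage can be illustrated with a minimal stand-in that mirrors the documented signature. The attribute names, the example alphabet, and the example `args` keys are assumptions for illustration only, not part of the documented API:

```python
from typing import Any, Dict


class CheckpointSketch:
    """Minimal stand-in mirroring the documented constructor signature."""

    def __init__(self, file_paths: dict, tokens: str, args: dict, use_wandb: bool = False):
        self.file_paths = file_paths  # paths where params/chains will be written
        self.tokens = tokens          # alphabet used to encode the sequences
        self.args = args              # training arguments
        self.use_wandb = use_wandb    # optional Weights & Biases logging
        self.logs: Dict[str, Any] = {}  # accumulated log records (assumed attribute)


# Hypothetical usage; the dictionary keys shown are illustrative.
ckpt = CheckpointSketch(
    file_paths={"params": "params.dat", "chains": "chains.fasta"},
    tokens="-ACDEFGHIKLMNPQRSTVWY",  # e.g. a protein alphabet with a gap symbol
    args={"nsweeps": 10, "lr": 0.05},
)
```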
method check
check(updates: int) → bool
Checks if a checkpoint has been reached.
Args:
updates(int): Number of gradient updates performed.
Returns:
bool: Whether a checkpoint has been reached.
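One plausible way `check` could decide, assuming checkpoints fire every fixed number of gradient updates (the fixed-interval rule and the `interval` parameter are assumptions, not taken from the documented API):

```python
def reached_checkpoint(updates: int, interval: int = 100) -> bool:
    """Return True when the update count lands on a checkpoint boundary.

    Mirrors the documented check(updates) -> bool contract; the
    fixed-interval rule is assumed for illustration.
    """
    return updates > 0 and updates % interval == 0
```

With `interval=100`, updates 100, 200, 300, ... would be checkpoints, while update 0 and intermediate counts would not.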
method log
log(record: Dict[str, Any]) → None
Adds key-value pairs to the log dictionary.
Args:
record(Dict[str, Any]): Key-value pairs to be added to the log dictionary.
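In terms of the documented contract, `log` merges the record into a running dictionary. A sketch, assuming an in-place dictionary merge where later records overwrite earlier keys:

```python
from typing import Any, Dict


def log(logs: Dict[str, Any], record: Dict[str, Any]) -> None:
    """Merge key-value pairs into the running log dictionary in place."""
    logs.update(record)


# Illustrative keys; the real training loop defines what gets logged.
logs: Dict[str, Any] = {}
log(logs, {"epoch": 1, "pearson": 0.93})
log(logs, {"epoch": 2})  # a later record overwrites the earlier "epoch"
```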
method save
save(
params: Dict[str, Tensor],
mask: Tensor,
chains: Tensor,
log_weights: Tensor
) → None
Saves the chains and the parameters of the model.
Args:
params(Dict[str, torch.Tensor]): Parameters of the model.
mask(torch.Tensor): Mask of the model's coupling matrix representing the interaction graph.
chains(torch.Tensor): The current sampling chains.
log_weights(torch.Tensor): Log of the chain weights. Used for AIS.
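How `save` might persist its inputs, sketched with `pickle` in place of the library's actual serialization. The file names, the `file_paths` keys, and the on-disk format are assumptions; plain lists stand in for `torch.Tensor` so the sketch has no external dependencies:

```python
import os
import pickle
import tempfile


def save_checkpoint(file_paths: dict, params: dict, mask, chains, log_weights) -> None:
    """Write the parameters and the chain state to the configured paths."""
    with open(file_paths["params"], "wb") as f:
        pickle.dump({"params": params, "mask": mask}, f)
    with open(file_paths["chains"], "wb") as f:
        pickle.dump({"chains": chains, "log_weights": log_weights}, f)


# Round-trip example with illustrative data.
tmpdir = tempfile.mkdtemp()
paths = {
    "params": os.path.join(tmpdir, "params.pkl"),
    "chains": os.path.join(tmpdir, "chains.pkl"),
}
save_checkpoint(paths, {"bias": [0.1, -0.2]}, [[1, 0], [0, 1]], [[0, 1, 2]], [0.0])

with open(paths["chains"], "rb") as f:
    restored = pickle.load(f)
```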
This file was automatically generated via lazydocs.