module adabmDCA.checkpoint
class Checkpoint
Helper class that saves the model's parameters and chains at regular intervals during training and logs training progress.
method __init__
__init__(
file_paths: dict,
tokens: str,
args: dict,
params: Optional[Dict[str, Tensor]] = None,
chains: Optional[Tensor] = None,
use_wandb: bool = False
)
Initializes the Checkpoint class.
Args:
file_paths(dict): Dictionary containing the paths of the files to be saved.
tokens(str): Alphabet used to encode the sequences.
args(dict): Dictionary containing the training arguments.
params(Optional[Dict[str, torch.Tensor]], optional): Parameters of the model. Defaults to None.
chains(Optional[torch.Tensor], optional): Markov chains (sampled sequences). Defaults to None.
use_wandb(bool, optional): Whether to use Weights & Biases for logging. Defaults to False.
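The tokens argument is an alphabet string. A minimal sketch of how such a string can map sequences to integer indices and back, assuming each character is encoded by its position in the string. The alphabet value and the helper names encode and decode are hypothetical, not part of adabmDCA.

```python
# Hypothetical alphabet: the gap symbol followed by the 20 standard amino acids.
# The alphabet actually passed as `tokens` depends on the training setup.
TOKENS = "-ACDEFGHIKLMNPQRSTVWY"


def encode(sequence: str, tokens: str = TOKENS) -> list[int]:
    """Map each character to its index in the alphabet string."""
    lookup = {char: idx for idx, char in enumerate(tokens)}
    try:
        return [lookup[char] for char in sequence]
    except KeyError as err:
        raise ValueError(f"character {err.args[0]!r} is not in the alphabet") from None


def decode(indices: list[int], tokens: str = TOKENS) -> str:
    """Map indices back to characters, e.g. to write chains as text."""
    return "".join(tokens[idx] for idx in indices)
```

Building the lookup from the string itself keeps the encoding consistent with whatever alphabet is supplied, so the same tokens value can be reused when saving chains.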
method check
check(updates: int) → bool
Checks if a checkpoint has been reached.
Args:
updates(int): Number of gradient updates performed.
Returns:
bool: Whether a checkpoint has been reached.
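A minimal sketch of one possible checkpoint criterion, assuming a fixed interval of gradient updates. The function is_checkpoint and its interval parameter are hypothetical; the criterion used by Checkpoint.check may differ.

```python
def is_checkpoint(updates: int, interval: int) -> bool:
    """Return True every `interval` gradient updates (hypothetical fixed schedule)."""
    if interval <= 0:
        raise ValueError("interval must be a positive integer")
    # Skip update 0 so that no checkpoint fires before any training has happened.
    return updates > 0 and updates % interval == 0


# With interval=50, checkpoints fire at updates 50, 100, 150 and 200.
fired = [u for u in range(1, 201) if is_checkpoint(u, interval=50)]
```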
method log
log(record: Dict[str, Any]) → None
Adds key-value pairs to the log dictionary.
Args:
record(Dict[str, Any]): Key-value pairs to be added to the log dictionary.
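A minimal sketch of the logging pattern described above, assuming records are merged into a single dictionary that is periodically written out and reset. The class ProgressLog, its flush method, the JSON-lines output, and the keys in the example are all hypothetical illustrations, not the Checkpoint internals.

```python
import io
import json
from typing import Any, Dict, TextIO


class ProgressLog:
    """Hypothetical stand-in for a log dictionary filled via log()."""

    def __init__(self) -> None:
        self.logs: Dict[str, Any] = {}

    def log(self, record: Dict[str, Any]) -> None:
        # Later values overwrite earlier ones for the same key.
        self.logs.update(record)

    def flush(self, stream: TextIO) -> None:
        # Write the current record as one JSON line, then start a fresh record.
        stream.write(json.dumps(self.logs, sort_keys=True) + "\n")
        self.logs = {}


buf = io.StringIO()
plog = ProgressLog()
plog.log({"updates": 100, "pearson": 0.91})
plog.log({"pearson": 0.93})
plog.flush(buf)
```

Because log() uses dict.update, logging the same key twice before a flush keeps only the latest value.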
method save
save(
params: Dict[str, Tensor],
mask: Tensor,
chains: Tensor,
log_weights: Tensor
) → None
Saves the chains and the parameters of the model.
Args:
params(Dict[str, torch.Tensor]): Parameters of the model.
mask(torch.Tensor): Mask of the model's coupling matrix, representing the interaction graph.
chains(torch.Tensor): Markov chains (sampled sequences).
log_weights(torch.Tensor): Log of the chain weights. Used for annealed importance sampling (AIS).
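The log_weights argument holds one log-weight per chain for AIS. A minimal sketch of how such weights are typically combined, assuming they are standard AIS importance weights: the estimate of log(Z_target / Z_start) is logsumexp(log_weights) - log(N), and the effective sample size is (sum w)^2 / sum w^2. The helper names are hypothetical, and plain Python floats stand in for a torch.Tensor.

```python
import math
from typing import Sequence


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def ais_log_ratio(log_weights: Sequence[float]) -> float:
    """AIS estimate of log(Z_target / Z_start): log of the mean chain weight."""
    return logsumexp(log_weights) - math.log(len(log_weights))


def effective_sample_size(log_weights: Sequence[float]) -> float:
    """ESS = (sum w)^2 / sum w^2, computed in log space to avoid overflow."""
    return math.exp(2 * logsumexp(log_weights) - logsumexp([2 * w for w in log_weights]))
```

Working in log space matters because AIS weights can span many orders of magnitude; exponentiating them directly would overflow or underflow.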
This file was automatically generated via lazydocs.