module adabmDCA.training
function update_params
update_params(
fi: Tensor,
fij: Tensor,
pi: Tensor,
pij: Tensor,
params: Dict[str, Tensor],
mask: Tensor,
lr: float
) → Dict[str, Tensor]
Updates the parameters of the model.
Args:
fi (torch.Tensor): Single-point frequencies of the data.
fij (torch.Tensor): Two-point frequencies of the data.
pi (torch.Tensor): Single-point marginals of the model.
pij (torch.Tensor): Two-point marginals of the model.
params (Dict[str, torch.Tensor]): Parameters of the model.
mask (torch.Tensor): Mask of the interaction graph.
lr (float): Learning rate.
Returns:
Dict[str, torch.Tensor]: Updated parameters.
function train_graph
train_graph(
sampler: Callable,
chains: Tensor,
mask: Tensor,
fi_target: Tensor,
fij_target: Tensor,
params: Dict[str, Tensor],
nsweeps: int,
lr: float,
max_epochs: int,
target_pearson: float,
fi_test: Optional[Tensor] = None,
fij_test: Optional[Tensor] = None,
checkpoint: Optional[Checkpoint] = None,
check_slope: bool = False,
log_weights: Optional[Tensor] = None,
progress_bar: bool = True,
*args,
**kwargs
) → Tuple[Tensor, Dict[str, Tensor], Tensor, Dict[str, List[float]]]
Trains the model on a given graph until the target Pearson correlation is reached or the maximum number of epochs is exceeded.
Args:
sampler (Callable): Sampling function.
chains (torch.Tensor): Markov chains simulated with the model.
mask (torch.Tensor): Mask encoding the sparse graph.
fi_target (torch.Tensor): Single-point frequencies of the data.
fij_target (torch.Tensor): Two-point frequencies of the data.
params (Dict[str, torch.Tensor]): Parameters of the model.
nsweeps (int): Number of Gibbs steps for each gradient estimation.
lr (float): Learning rate.
max_epochs (int): Maximum number of gradient updates to be performed.
target_pearson (float): Target Pearson correlation coefficient.
fi_test (Optional[torch.Tensor], optional): Single-point frequencies of the test data. Defaults to None.
fij_test (Optional[torch.Tensor], optional): Two-point frequencies of the test data. Defaults to None.
checkpoint (Optional[Checkpoint], optional): Checkpoint class used to save the model. Defaults to None.
check_slope (bool, optional): Whether to include the slope in the convergence criterion. Defaults to False.
log_weights (Optional[torch.Tensor], optional): Log-weights used for the online computation of the log-likelihood. Defaults to None.
progress_bar (bool, optional): Whether to display a progress bar. Defaults to True.
Returns:
Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor, Dict[str, List[float]]]: Updated chains and parameters, log-weights for the log-likelihood computation, and training history.
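A minimal usage sketch follows. The tensor shapes, the one-hot chain encoding, the parameter key names, and the dummy sampler signature are all assumptions made for illustration; in practice the inputs come from the alignment statistics and one of the package's samplers.

```python
import torch
from adabmDCA.training import train_graph

L, q, N = 50, 21, 1000  # hypothetical number of positions, states, and chains

# Placeholder statistics and initialization (uniform frequencies, zero parameters).
fi_target = torch.full((L, q), 1.0 / q)
fij_target = torch.full((L, q, L, q), 1.0 / q**2)
params = {"bias": torch.zeros(L, q), "coupling_matrix": torch.zeros(L, q, L, q)}  # key names assumed
mask = torch.ones(L, q, L, q)  # fully connected interaction graph

# One-hot encoded chains; the encoding expected by the package is an assumption.
chains = torch.nn.functional.one_hot(torch.randint(0, q, (N, L)), num_classes=q).float()

def dummy_sampler(chains, params, nsweeps, **kwargs):
    # Stand-in for the package's sampler; it should return the chains after
    # nsweeps Monte Carlo sweeps with the current parameters.
    return chains

chains, params, log_weights, history = train_graph(
    sampler=dummy_sampler,
    chains=chains,
    mask=mask,
    fi_target=fi_target,
    fij_target=fij_target,
    params=params,
    nsweeps=10,
    lr=0.05,
    max_epochs=100,
    target_pearson=0.95,
)
```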
function train_eaDCA
train_eaDCA(
sampler: Callable,
fi_target: Tensor,
fij_target: Tensor,
params: Dict[str, Tensor],
mask: Tensor,
chains: Tensor,
log_weights: Tensor,
target_pearson: float,
nsweeps: int,
max_epochs: int,
pseudo_count: float,
lr: float,
factivate: float,
gsteps: int,
fi_test: Optional[Tensor] = None,
fij_test: Optional[Tensor] = None,
checkpoint: Optional[Checkpoint] = None,
*args,
**kwargs
) → Tuple[Tensor, Dict[str, Tensor], Tensor, Dict[str, List[float]]]
Fits an eaDCA model on the training data and saves the results in a file.
Args:
sampler (Callable): Sampling function to be used.
fi_target (torch.Tensor): Single-point frequencies of the data.
fij_target (torch.Tensor): Two-point frequencies of the data.
params (Dict[str, torch.Tensor]): Initialization of the model's parameters.
mask (torch.Tensor): Initialization of the coupling matrix's mask.
chains (torch.Tensor): Initialization of the Markov chains.
log_weights (torch.Tensor): Log-weights of the chains, used to estimate the log-likelihood.
target_pearson (float): Pearson correlation coefficient on the two-point statistics to be reached.
nsweeps (int): Number of Monte Carlo steps to update the state of the model.
max_epochs (int): Maximum number of epochs to be performed.
pseudo_count (float): Pseudo-count for the single- and two-point statistics. Acts as a regularizer.
lr (float): Learning rate.
factivate (float): Fraction of inactive couplings to activate at each step.
gsteps (int): Number of gradient updates to be performed on a given graph.
fi_test (Optional[torch.Tensor], optional): Single-point frequencies of the test data. Defaults to None.
fij_test (Optional[torch.Tensor], optional): Two-point frequencies of the test data. Defaults to None.
checkpoint (Optional[Checkpoint], optional): Checkpoint class used to save the model. Defaults to None.
Returns:
Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor, Dict[str, List[float]]]: Updated chains and parameters, log-weights for the log-likelihood computation, and training history.
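To make the roles of factivate and gsteps concrete, here is a hedged sketch of one element-activation step, reconstructed from the parameter descriptions above. The rule used to select which couplings to activate (largest mismatch between data and model two-point statistics) is an assumption and may differ from the package's implementation.

```python
import torch

def activation_step_sketch(mask, fij_target, pij, factivate):
    """Activate a fraction factivate of the currently inactive couplings, here
    chosen as those with the largest data/model two-point mismatch
    (the selection criterion is an assumption)."""
    inactive = mask == 0
    n_activate = int(factivate * inactive.sum().item())
    score = (fij_target - pij).abs() * inactive   # already active entries get score 0
    idx = torch.topk(score.flatten(), n_activate).indices
    new_mask = mask.clone().reshape(-1)
    new_mask[idx] = 1
    return new_mask.reshape(mask.shape)

# After each activation step, the enlarged graph is trained for gsteps gradient
# updates (cf. train_graph), and the loop repeats until target_pearson is reached.
```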
function train_edDCA
train_edDCA(
sampler: Callable,
chains: Tensor,
log_weights: Tensor,
fi_target: Tensor,
fij_target: Tensor,
params: Dict[str, Tensor],
mask: Tensor,
lr: float,
nsweeps: int,
target_pearson: float,
target_density: float,
drate: float,
checkpoint: Optional[Checkpoint] = None,
fi_test: Optional[Tensor] = None,
fij_test: Optional[Tensor] = None,
*args,
**kwargs
) → Tuple[Tensor, Dict[str, Tensor], Tensor, Dict[str, List[float]]]
Fits an edDCA model on the training data and saves the results in a file.
Args:
sampler (Callable): Sampling function to be used.
chains (torch.Tensor): Initialization of the Markov chains.
log_weights (torch.Tensor): Log-weights of the chains, used to estimate the log-likelihood.
fi_target (torch.Tensor): Single-point frequencies of the data.
fij_target (torch.Tensor): Two-point frequencies of the data.
params (Dict[str, torch.Tensor]): Initialization of the model's parameters.
mask (torch.Tensor): Initialization of the coupling matrix's mask.
lr (float): Learning rate.
nsweeps (int): Number of Monte Carlo steps to update the state of the model.
target_pearson (float): Pearson correlation coefficient on the two-point statistics to be reached.
target_density (float): Target density of the coupling matrix.
drate (float): Percentage of active couplings to be pruned at each decimation step.
checkpoint (Optional[Checkpoint], optional): Checkpoint class used to save the model. Defaults to None.
fi_test (Optional[torch.Tensor], optional): Single-point frequencies of the test data. Defaults to None.
fij_test (Optional[torch.Tensor], optional): Two-point frequencies of the test data. Defaults to None.
Returns:
Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor, Dict[str, List[float]]]: Updated chains and parameters, log-weights for the log-likelihood computation, and training history.
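For intuition about drate and target_density, the sketch below shows one plausible decimation step: prune the weakest active couplings, zero them in the coupling matrix, and keep the rest. The magnitude-based pruning criterion and the "coupling_matrix" key name are assumptions made for illustration; the package's actual decimation rule may differ.

```python
import torch

def decimation_step_sketch(mask, params, drate):
    """Prune a fraction drate of the active couplings, here the ones with the
    smallest magnitude (the actual pruning criterion may differ)."""
    J = params["coupling_matrix"]              # key name is an assumption
    active = mask.bool()
    n_prune = int(drate * active.sum().item())
    # Inactive entries get an infinite score so they are never selected.
    score = torch.where(active, J.abs(), torch.full_like(J, float("inf")))
    idx = torch.topk(score.flatten(), n_prune, largest=False).indices
    new_mask = mask.clone().reshape(-1)
    new_mask[idx] = 0
    new_mask = new_mask.reshape(mask.shape)
    params["coupling_matrix"] = J * new_mask   # zero out the pruned couplings
    return new_mask, params

# The decimation loop alternates pruning steps with retraining on the pruned graph
# (cf. train_graph) and stops once the density of the mask reaches target_density.
```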
This file was automatically generated via lazydocs.