module adabmDCA.training


function update_params

update_params(
    fi: Tensor,
    fij: Tensor,
    pi: Tensor,
    pij: Tensor,
    params: Dict[str, Tensor],
    mask: Tensor,
    lr: float
) → Dict[str, Tensor]

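Updates the parameters of the model with a step of size lr, using the data frequencies (fi, fij), the model marginals (pi, pij), and the mask of the interaction graph.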

Args:

  • fi (torch.Tensor): Single-point frequencies of the data.
  • fij (torch.Tensor): Two-point frequencies of the data.
  • pi (torch.Tensor): Single-point marginals of the model.
  • pij (torch.Tensor): Two-point marginals of the model.
  • params (Dict[str, torch.Tensor]): Parameters of the model.
  • mask (torch.Tensor): Mask of the interaction graph.
  • lr (float): Learning rate.

Returns:

  • Dict[str, torch.Tensor]: Updated parameters.
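
As a rough illustration of what an update of this kind can look like, the sketch below applies one moment-matching step on plain Python lists instead of tensors. The update rule (data frequency minus model marginal, scaled by lr, with the coupling update gated by mask) and the dictionary keys "bias" and "coupling_matrix" are assumptions made for this example; the reference above does not spell them out, and `update_params_sketch` is a hypothetical name.

```python
from typing import Dict, List


def update_params_sketch(
    fi: List[float],
    fij: List[float],
    pi: List[float],
    pij: List[float],
    params: Dict[str, List[float]],
    mask: List[float],
    lr: float,
) -> Dict[str, List[float]]:
    """Illustrative moment-matching step on flat lists (not the library code)."""
    # Assumed rule: move each bias along (data frequency - model marginal).
    bias = [h + lr * (f - p) for h, f, p in zip(params["bias"], fi, pi)]
    # Assumed rule: same for couplings, but only where the mask is non-zero.
    coupling = [
        j + lr * m * (f - p)
        for j, f, p, m in zip(params["coupling_matrix"], fij, pij, mask)
    ]
    return {"bias": bias, "coupling_matrix": coupling}


params = {"bias": [0.0, 0.0], "coupling_matrix": [0.0, 0.0]}
new_params = update_params_sketch(
    fi=[0.75, 0.25],
    fij=[0.5, 0.5],
    pi=[0.5, 0.5],
    pij=[0.25, 0.25],
    params=params,
    mask=[1.0, 0.0],  # second coupling is inactive and must stay at zero
    lr=0.5,
)
print(new_params)
```

In this toy call the second coupling stays at zero because its mask entry is zero, which is the role the mask argument appears to play in the signature above.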

function train_graph

train_graph(
    sampler: Callable,
    chains: Tensor,
    mask: Tensor,
    fi_target: Tensor,
    fij_target: Tensor,
    params: Dict[str, Tensor],
    nsweeps: int,
    lr: float,
    max_epochs: int,
    target_pearson: float,
    fi_test: Optional[Tensor] = None,
    fij_test: Optional[Tensor] = None,
    checkpoint: Optional[Checkpoint] = None,
    check_slope: bool = False,
    log_weights: Optional[Tensor] = None,
    progress_bar: bool = True,
    *args,
    **kwargs
) → Tuple[Tensor, Dict[str, Tensor], Tensor, Dict[str, List[float]]]

Trains the model on a given graph until the target Pearson correlation is reached or the maximum number of epochs is exceeded.

Args:

  • sampler (Callable): Sampling function.
  • chains (torch.Tensor): Markov chains simulated with the model.
  • mask (torch.Tensor): Mask encoding the sparse graph.
  • fi_target (torch.Tensor): Single-point frequencies of the data.
  • fij_target (torch.Tensor): Two-point frequencies of the data.
  • params (Dict[str, torch.Tensor]): Parameters of the model.
  • nsweeps (int): Number of Gibbs steps for each gradient estimation.
  • lr (float): Learning rate.
  • max_epochs (int): Maximum number of gradient updates to be done.
  • target_pearson (float): Target Pearson coefficient.
  • fi_test (Optional[torch.Tensor], optional): Single-point frequencies of the test data. Defaults to None.
  • fij_test (Optional[torch.Tensor], optional): Two-point frequencies of the test data. Defaults to None.
  • checkpoint (Optional[Checkpoint], optional): Checkpoint class to be used for saving the model. Defaults to None.
  • check_slope (bool, optional): Whether the convergence criterion also takes the slope into account. Defaults to False.
  • log_weights (Optional[torch.Tensor], optional): Log-weights used for the online computation of the log-likelihood. Defaults to None.
  • progress_bar (bool, optional): Whether to display a progress bar or not. Defaults to True.

Returns:

  • Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor, Dict[str, List[float]]]: Updated chains and parameters, log-weights for the log-likelihood computation, and training history.
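
The stopping rule described for train_graph (stop once the target Pearson correlation is reached, or once max_epochs updates have been made) can be sketched in plain Python. Everything here is a toy under stated assumptions: `pearson` and `train_until_converged` are hypothetical names, flat lists stand in for tensors, and the relaxation step is only a stand-in for the real sampling and parameter update.

```python
import math
from typing import List, Tuple


def pearson(x: List[float], y: List[float]) -> float:
    """Pearson correlation of two equal-length lists (neither may be constant)."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    norm_x = math.sqrt(sum((a - mean_x) ** 2 for a in x))
    norm_y = math.sqrt(sum((b - mean_y) ** 2 for b in y))
    return cov / (norm_x * norm_y)


def train_until_converged(
    target_stats: List[float],
    model_stats: List[float],
    lr: float,
    max_epochs: int,
    target_pearson: float,
) -> Tuple[List[float], int, float]:
    """Iterate until the Pearson target is met or max_epochs is reached."""
    epochs = 0
    score = pearson(target_stats, model_stats)
    while score < target_pearson and epochs < max_epochs:
        # Toy stand-in for "sample, then update parameters": relax the
        # model statistics toward the target statistics.
        model_stats = [p + lr * (f - p) for f, p in zip(target_stats, model_stats)]
        score = pearson(target_stats, model_stats)
        epochs += 1
    return model_stats, epochs, score


target = [0.1, 0.4, 0.2, 0.3]
start = [0.3, 0.2, 0.25, 0.25]
stats, epochs, score = train_until_converged(target, start, 0.5, 100, 0.95)
_, capped_epochs, _ = train_until_converged(target, start, 0.5, 3, 2.0)
```

The second call uses an unreachable target on purpose, so the loop ends only because the epoch cap is hit.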

function train_eaDCA

train_eaDCA(
    sampler: Callable,
    fi_target: Tensor,
    fij_target: Tensor,
    params: Dict[str, Tensor],
    mask: Tensor,
    chains: Tensor,
    log_weights: Tensor,
    target_pearson: float,
    nsweeps: int,
    max_epochs: int,
    pseudo_count: float,
    lr: float,
    factivate: float,
    gsteps: int,
    fi_test: Optional[Tensor] = None,
    fij_test: Optional[Tensor] = None,
    checkpoint: Optional[Checkpoint] = None,
    *args,
    **kwargs
) → Tuple[Tensor, Dict[str, Tensor], Tensor, Dict[str, List[float]]]

Fits an eaDCA model on the training data and saves the results in a file.

Args:

  • sampler (Callable): Sampling function to be used.
  • fi_target (torch.Tensor): Single-point frequencies of the data.
  • fij_target (torch.Tensor): Two-point frequencies of the data.
  • params (Dict[str, torch.Tensor]): Initialization of the model's parameters.
  • mask (torch.Tensor): Initialization of the coupling matrix's mask.
  • chains (torch.Tensor): Initialization of the Markov chains.
  • log_weights (torch.Tensor): Log-weights of the chains. Used to estimate the log-likelihood.
  • target_pearson (float): Pearson correlation coefficient on the two-point statistics to be reached.
  • nsweeps (int): Number of Monte Carlo steps to update the state of the model.
  • max_epochs (int): Maximum number of epochs to be performed.
  • pseudo_count (float): Pseudo-count for the single- and two-point statistics. Acts as a regularization.
  • lr (float): Learning rate.
  • factivate (float): Fraction of inactive couplings to activate at each step.
  • gsteps (int): Number of gradient updates to be performed on a given graph.
  • fi_test (Optional[torch.Tensor], optional): Single-point frequencies of the test data. Defaults to None.
  • fij_test (Optional[torch.Tensor], optional): Two-point frequencies of the test data. Defaults to None.
  • checkpoint (Optional[Checkpoint], optional): Checkpoint class to be used to save the model. Defaults to None.

Returns:

  • Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor, Dict[str, List[float]]]: Updated chains and parameters, log-weights for the log-likelihood computation, and training history.
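
To give a feel for how factivate shapes the growth of the graph, the sketch below tracks only the number of active couplings when a fixed fraction of the inactive ones is activated at each step. It is a counting exercise under assumptions: `activation_schedule` is a hypothetical helper, rounding down with a minimum of one activation per step is a choice made for this example, and how the library picks which couplings to activate is not modeled.

```python
from typing import List


def activation_schedule(n_couplings: int, factivate: float, n_steps: int) -> List[int]:
    """Cumulative number of active couplings after each activation step."""
    if not 0.0 < factivate <= 1.0:
        raise ValueError("factivate must be in (0, 1]")
    active = 0
    history = []
    for _ in range(n_steps):
        inactive = n_couplings - active
        if inactive == 0:
            break  # the graph is already fully connected
        # Assumed rounding: floor, but always activate at least one coupling.
        n_new = max(1, int(factivate * inactive))
        active += n_new
        history.append(active)
    return history


schedule = activation_schedule(n_couplings=64, factivate=0.25, n_steps=3)
print(schedule)  # [16, 28, 37]
```

Because the fraction applies to the couplings that are still inactive, the number activated per step shrinks as the graph fills up (16, then 12, then 9 in this run).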

function train_edDCA

train_edDCA(
    sampler: Callable,
    chains: Tensor,
    log_weights: Tensor,
    fi_target: Tensor,
    fij_target: Tensor,
    params: Dict[str, Tensor],
    mask: Tensor,
    lr: float,
    nsweeps: int,
    target_pearson: float,
    target_density: float,
    drate: float,
    checkpoint: Optional[Checkpoint] = None,
    fi_test: Optional[Tensor] = None,
    fij_test: Optional[Tensor] = None,
    *args,
    **kwargs
) → Tuple[Tensor, Dict[str, Tensor], Tensor, Dict[str, List[float]]]

Fits an edDCA model on the training data and saves the results in a file.

Args:

  • sampler (Callable): Sampling function to be used.
  • chains (torch.Tensor): Initialization of the Markov chains.
  • log_weights (torch.Tensor): Log-weights of the chains. Used to estimate the log-likelihood.
  • fi_target (torch.Tensor): Single-point frequencies of the data.
  • fij_target (torch.Tensor): Two-point frequencies of the data.
  • params (Dict[str, torch.Tensor]): Initialization of the model's parameters.
  • mask (torch.Tensor): Initialization of the coupling matrix's mask.
  • lr (float): Learning rate.
  • nsweeps (int): Number of Monte Carlo steps to update the state of the model.
  • target_pearson (float): Pearson correlation coefficient on the two-point statistics to be reached.
  • target_density (float): Target density of the coupling matrix.
  • drate (float): Percentage of active couplings to be pruned at each decimation step.
  • checkpoint (Optional[Checkpoint], optional): Checkpoint class to be used to save the model. Defaults to None.
  • fi_test (Optional[torch.Tensor], optional): Single-point frequencies of the test data. Defaults to None.
  • fij_test (Optional[torch.Tensor], optional): Two-point frequencies of the test data. Defaults to None.

Returns:

  • Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor, Dict[str, List[float]]]: Updated chains and parameters, log-weights for the log-likelihood computation, and training history.
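
The interplay between target_density and drate can be illustrated by following the density of the coupling matrix across decimation steps. The sketch assumes that drate is expressed as a fraction in (0, 1) and that pruning repeats until the density is at or below the target; `decimation_schedule` is a hypothetical helper, and whether the library stops exactly at the target or slightly past it is not stated above.

```python
from typing import List


def decimation_schedule(density: float, target_density: float, drate: float) -> List[float]:
    """Density of the coupling matrix after each decimation step."""
    if not 0.0 < drate < 1.0:
        raise ValueError("drate must be in (0, 1)")
    history = []
    while density > target_density:
        # Each step prunes a fraction drate of the currently active couplings.
        density *= 1.0 - drate
        history.append(density)
    return history


densities = decimation_schedule(density=1.0, target_density=0.1, drate=0.5)
print(densities)  # [0.5, 0.25, 0.125, 0.0625]
```

Under these assumptions the density after k steps is the initial density times (1 - drate)^k, so a small drate means many more decimation steps before the target is reached.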

This file was automatically generated via lazydocs.