module sampling


function gibbs_sampling

gibbs_sampling(
    chains: Tensor,
    params: Dict[str, Tensor],
    nsweeps: int,
    beta: float = 1.0
) → Tensor

Updates the chains by performing nsweeps sweeps of Gibbs sampling.

Args:

  • chains (torch.Tensor): Initial chains.
  • params (Dict[str, torch.Tensor]): Parameters of the model.
  • nsweeps (int): Number of sweeps.
  • beta (float, optional): Inverse temperature. Defaults to 1.0.

Returns:

  • torch.Tensor: Updated chains.
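
The following is a minimal sketch of what a Gibbs sweep does and how beta enters it. It is not the module's implementation. It assumes a Potts-like model with per-site fields and pairwise couplings, and it uses plain Python lists of integer states instead of one-hot tensors. The names gibbs_sweeps, local_field, fields and couplings are hypothetical and do not come from the module.

```python
import math
import random


def local_field(chain, i, a, fields, couplings):
    """Log-weight of state `a` at site `i`, given the other sites of `chain`."""
    total = fields[i][a]
    for j, b in enumerate(chain):
        if j != i:
            total += couplings[i][j][a][b]
    return total


def gibbs_sweeps(chain, fields, couplings, nsweeps, beta=1.0, rng=random):
    """Return a copy of `chain` after `nsweeps` Gibbs sweeps (illustrative only)."""
    chain = list(chain)  # work on a copy so the input is left untouched
    n_states = len(fields[0])
    for _ in range(nsweeps):
        for i in range(len(chain)):
            # Conditional distribution of site i: p(a) is proportional to
            # exp(beta * local_field), so larger beta sharpens the distribution.
            logits = [beta * local_field(chain, i, a, fields, couplings)
                      for a in range(n_states)]
            top = max(logits)  # subtract the maximum for numerical stability
            weights = [math.exp(x - top) for x in logits]
            chain[i] = rng.choices(range(n_states), weights=weights)[0]
    return chain
```

In this sketch, each site is resampled from its conditional distribution, so every proposed update is kept. With beta = 1.0 the chain samples the model as defined by its parameters, and larger values concentrate the samples on the highest-weight states.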

function metropolis

metropolis(
    chains: Tensor,
    params: Dict[str, Tensor],
    nsweeps: int,
    beta: float = 1.0
) → Tensor

Updates the chains by performing nsweeps sweeps of Metropolis sampling.

Args:

  • chains (torch.Tensor): One-hot encoded sequences.
  • params (Dict[str, torch.Tensor]): Parameters of the model.
  • nsweeps (int): Number of sweeps to be performed.
  • beta (float, optional): Inverse temperature. Defaults to 1.0.

Returns:

  • torch.Tensor: Updated chains.
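
As a rough illustration of the Metropolis rule and the role of beta, the sketch below proposes a new state for each site in turn and accepts it with probability min(1, exp(-beta * delta_e)). It is not the module's implementation. It assumes a Potts-like model with at least two states per site, and it stores chains as plain lists of integer states instead of one-hot tensors. The names metropolis_sweeps, local_field, fields and couplings are hypothetical.

```python
import math
import random


def local_field(chain, i, a, fields, couplings):
    """Log-weight of state `a` at site `i`, given the other sites of `chain`."""
    total = fields[i][a]
    for j, b in enumerate(chain):
        if j != i:
            total += couplings[i][j][a][b]
    return total


def metropolis_sweeps(chain, fields, couplings, nsweeps, beta=1.0, rng=random):
    """Return a copy of `chain` after `nsweeps` Metropolis sweeps (illustrative only)."""
    chain = list(chain)  # work on a copy so the input is left untouched
    n_states = len(fields[0])  # assumed to be at least 2
    for _ in range(nsweeps):
        for i in range(len(chain)):
            old = chain[i]
            # Symmetric proposal: pick uniformly among the other states.
            new = rng.randrange(n_states - 1)
            if new >= old:
                new += 1
            # Energy is minus the log-weight, so this is E(new) - E(old).
            delta_e = (local_field(chain, i, old, fields, couplings)
                       - local_field(chain, i, new, fields, couplings))
            # Always accept downhill moves; accept uphill moves with
            # probability exp(-beta * delta_e).
            if delta_e <= 0 or rng.random() < math.exp(-beta * delta_e):
                chain[i] = new
    return chain
```

Under these assumptions, beta = 0 accepts every proposal, while a large beta rejects almost every move that raises the energy.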

function get_sampler

get_sampler(sampling_method: str) → Callable

Returns the sampling function corresponding to the chosen method.

Args:

  • sampling_method (str): Name of the sampling method, either 'metropolis' or 'gibbs'.

Raises:

  • KeyError: If sampling_method is not a known sampling method.

Returns:

  • Callable: Sampling function.
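
One plausible way to get the documented behavior, including the KeyError for unknown names, is a dictionary lookup. The sketch below is an assumption about the design, not the module's source. The two sampler functions are placeholders that return the chains unchanged so the example runs on its own, and the _SAMPLERS table is a hypothetical name.

```python
from typing import Callable, Dict


def gibbs_sampling(chains, params, nsweeps, beta=1.0):
    """Placeholder standing in for the real Gibbs sampler."""
    return chains


def metropolis(chains, params, nsweeps, beta=1.0):
    """Placeholder standing in for the real Metropolis sampler."""
    return chains


# Hypothetical lookup table from method name to sampling function.
_SAMPLERS: Dict[str, Callable] = {
    "gibbs": gibbs_sampling,
    "metropolis": metropolis,
}


def get_sampler(sampling_method: str) -> Callable:
    """Return the sampling function registered under `sampling_method`."""
    try:
        return _SAMPLERS[sampling_method]
    except KeyError:
        raise KeyError(
            f"Unknown sampling method {sampling_method!r}. "
            "Choose between 'metropolis' and 'gibbs'."
        ) from None


# Both samplers share one signature, so the caller does not need to branch.
sampler = get_sampler("metropolis")
chains = sampler([[0, 1, 0]], params={}, nsweeps=10, beta=1.0)
```

Because both samplers take the same arguments, code that calls get_sampler can switch methods by changing only the string.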

This file was automatically generated via lazydocs.