module sampling
function gibbs_sampling
gibbs_sampling(
chains: Tensor,
params: Dict[str, Tensor],
nsweeps: int,
beta: float = 1.0
) → Tensor
Runs Gibbs sampling on the chains for nsweeps sweeps at inverse temperature beta.
Args:
chains (torch.Tensor): Initial chains.
params (Dict[str, torch.Tensor]): Parameters of the model.
nsweeps (int): Number of sweeps to perform.
beta (float, optional): Inverse temperature. Defaults to 1.0.
Returns:
torch.Tensor: Updated chains.
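The following is a minimal sketch of what one Gibbs sweep over a Potts model can look like; it is not this module's implementation. It assumes, hypothetically, that chains are one-hot tensors of shape (N, L, q), that params holds a "bias" tensor of shape (L, q) and a symmetric "coupling_matrix" of shape (L, q, L, q) with zero self-couplings, and that the energy is E(s) = -sum_i h_i(s_i) - sum_{i<j} J_ij(s_i, s_j). The name gibbs_sampling_sketch is illustrative only.

```python
from typing import Dict

import torch
import torch.nn.functional as F


def gibbs_sampling_sketch(
    chains: torch.Tensor,
    params: Dict[str, torch.Tensor],
    nsweeps: int,
    beta: float = 1.0,
) -> torch.Tensor:
    """Hypothetical Gibbs sampler for one-hot Potts chains of shape (N, L, q)."""
    h = params["bias"]             # assumed key, shape (L, q)
    J = params["coupling_matrix"]  # assumed key, shape (L, q, L, q), zero self-couplings
    chains = chains.clone()        # do not modify the caller's tensor
    N, L, q = chains.shape
    for _ in range(nsweeps):
        # One sweep: resample every site once, in random order.
        for i in torch.randperm(L).tolist():
            # Local field on site i for every chain, shape (N, q).
            field = h[i] + torch.einsum("ajb,njb->na", J[i], chains)
            # Conditional distribution P(s_i = a | rest) ∝ exp(beta * field[a]).
            probs = torch.softmax(beta * field, dim=-1)
            new_state = torch.multinomial(probs, num_samples=1).squeeze(-1)
            chains[:, i] = F.one_hot(new_state, num_classes=q).to(chains.dtype)
    return chains
```

The einsum contracts the couplings of site i with the current states of all sites, which is why the zero self-coupling assumption matters: otherwise site i would feel a field from its own old state.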
function metropolis
metropolis(
chains: Tensor,
params: Dict[str, Tensor],
nsweeps: int,
beta: float = 1.0
) → Tensor
Runs Metropolis sampling on the chains for nsweeps sweeps at inverse temperature beta.
Args:
chains (torch.Tensor): One-hot encoded sequences.
params (Dict[str, torch.Tensor]): Parameters of the model.
nsweeps (int): Number of sweeps to perform.
beta (float, optional): Inverse temperature. Defaults to 1.0.
Returns:
torch.Tensor: Updated chains.
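As a rough illustration of the Metropolis rule, here is a sketch under the same hypothetical assumptions as a Potts-model Gibbs sampler: one-hot chains of shape (N, L, q), a "bias" tensor of shape (L, q), and a symmetric "coupling_matrix" of shape (L, q, L, q) with zero self-couplings. The keys, the uniform single-site proposal, and the name metropolis_sketch are assumptions, not this module's documented behavior.

```python
from typing import Dict

import torch
import torch.nn.functional as F


def metropolis_sketch(
    chains: torch.Tensor,
    params: Dict[str, torch.Tensor],
    nsweeps: int,
    beta: float = 1.0,
) -> torch.Tensor:
    """Hypothetical single-site Metropolis sampler for one-hot Potts chains."""
    h = params["bias"]             # assumed key, shape (L, q)
    J = params["coupling_matrix"]  # assumed key, shape (L, q, L, q), zero self-couplings
    chains = chains.clone()
    N, L, q = chains.shape
    rows = torch.arange(N, device=chains.device)
    for _ in range(nsweeps):
        for i in torch.randperm(L).tolist():
            # Local field on site i, shape (N, q).
            field = h[i] + torch.einsum("ajb,njb->na", J[i], chains)
            old_state = chains[:, i].argmax(dim=-1)
            # Symmetric proposal: a uniformly random state for site i.
            proposal = torch.randint(0, q, (N,), device=chains.device)
            # Energy change dE = E(new) - E(old) = field[old] - field[new].
            delta_e = field[rows, old_state] - field[rows, proposal]
            # Accept with probability min(1, exp(-beta * dE)).
            accept = torch.rand(N, device=chains.device) < torch.exp(-beta * delta_e)
            new_state = torch.where(accept, proposal, old_state)
            chains[:, i] = F.one_hot(new_state, num_classes=q).to(chains.dtype)
    return chains
```

Because the proposal is symmetric, the acceptance ratio reduces to the Boltzmann factor of the energy change; proposing the current state simply leaves the site unchanged.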
function get_sampler
get_sampler(sampling_method: str) → Callable
Returns the sampling function corresponding to the chosen method.
Args:
sampling_method (str): Name of the sampling method, either 'metropolis' or 'gibbs'.
Raises:
KeyError: If the sampling method is unknown.
Returns:
Callable: Sampling function.
This file was automatically generated via lazydocs.