module adabmDCA.functional
function one_hot
one_hot(x: Tensor, num_classes: int = -1, dtype: dtype = torch.float32) → Tensor
A fast one-hot encoding function that works directly with torch.int32 input and is faster than the PyTorch built-in (torch.nn.functional.one_hot). It returns a float32 tensor by default (see dtype). Works for both 1D (single sequence) and 2D (batch of sequences) tensors.
Args:
 - x (torch.Tensor): Input tensor to be one-hot encoded. Shape (L,) or (batch_size, L).
 - num_classes (int, optional): Number of classes. If -1, the number of classes is inferred from the input tensor. Defaults to -1.
 - dtype (torch.dtype, optional): Data type of the output tensor. Defaults to torch.float32.
Returns:
torch.Tensor: One-hot encoded tensor. Shape (L, num_classes) for 1D input or (batch_size, L, num_classes) for 2D input.
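To make the documented input and output shapes concrete, here is a minimal sketch that reproduces the described behavior using only public PyTorch operations. The name one_hot_sketch is hypothetical and this is not the adabmDCA implementation: it casts indices to int64 because Tensor.scatter_ requires int64 indices, whereas the library function is described as working with torch.int32 directly. It also assumes that an inferred num_classes equals the largest index plus one.

```python
import torch


def one_hot_sketch(x: torch.Tensor, num_classes: int = -1,
                   dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Hypothetical re-implementation of the documented one_hot behavior.

    Input shape (L,) or (batch_size, L) -> output shape
    (L, num_classes) or (batch_size, L, num_classes).
    """
    if x.dim() not in (1, 2):
        raise ValueError("x must have shape (L,) or (batch_size, L)")
    if num_classes < 0:
        # Assumption: infer the number of classes as the largest index + 1.
        num_classes = int(x.max().item()) + 1
    # Allocate the output directly in the requested dtype and device.
    out = torch.zeros(*x.shape, num_classes, dtype=dtype, device=x.device)
    # scatter_ needs int64 indices, so int32 input is cast here; the
    # trailing singleton dimension addresses the class axis.
    idx = x.long().unsqueeze(-1)
    out.scatter_(-1, idx, torch.ones_like(idx, dtype=dtype))
    return out


# Usage: a batch of two sequences of length 3 stored as int32.
seqs = torch.tensor([[0, 2, 1], [3, 0, 0]], dtype=torch.int32)
encoded = one_hot_sketch(seqs)
print(encoded.shape)  # torch.Size([2, 3, 4])
```

Writing the ones into a preallocated tensor of the target dtype avoids the separate int64-to-float conversion that torch.nn.functional.one_hot would need, which is one plausible source of the speedup the description mentions.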
This file was automatically generated via lazydocs.