
module adabmDCA.functional


function one_hot

one_hot(x: Tensor, num_classes: int = -1, dtype: dtype = torch.float32) → Tensor

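A one-hot encoding function that is faster than the PyTorch built-in (torch.nn.functional.one_hot). It accepts torch.int32 input and returns a floating-point tensor of the requested dtype (torch.float32 by default). It works for both 1D (single sequence) and 2D (batch of sequences) tensors.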

Args:

  • x (torch.Tensor): Input tensor to be one-hot encoded. Shape (L,) or (batch_size, L).
  • num_classes (int, optional): Number of classes. If -1, the number of classes is inferred from the input tensor. Defaults to -1.
  • dtype (torch.dtype, optional): Data type of the output tensor. Defaults to torch.float32.

Returns:

  • torch.Tensor: One-hot encoded tensor. Shape (L, num_classes) for 1D input or (batch_size, L, num_classes) for 2D input.
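To make the input and output shapes above concrete, here is a minimal sketch of an equivalent one-hot encoder in plain PyTorch. It is not the adabmDCA implementation. The name one_hot_sketch, the scatter_-based approach, and inferring num_classes as the largest label plus one (the rule torch.nn.functional.one_hot uses) are assumptions made for illustration.

```python
import torch


def one_hot_sketch(
    x: torch.Tensor, num_classes: int = -1, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Hypothetical reference sketch; NOT the adabmDCA implementation."""
    if x.dim() not in (1, 2):
        raise ValueError("expected a tensor of shape (L,) or (batch_size, L)")
    if num_classes < 0:
        # Assumption: infer the class count as the largest label + 1,
        # mirroring torch.nn.functional.one_hot.
        num_classes = int(x.max().item()) + 1
    # Allocate zeros of shape (*x.shape, num_classes) directly in the target dtype.
    out = torch.zeros(*x.shape, num_classes, dtype=dtype, device=x.device)
    # scatter_ requires int64 indices, so cast (e.g. from torch.int32) and
    # write a 1 at each label position along the last dimension.
    out.scatter_(-1, x.long().unsqueeze(-1), 1)
    return out


if __name__ == "__main__":
    # A batch of two sequences of length 3 over 3 classes, stored as int32.
    seqs = torch.tensor([[0, 2, 1], [1, 0, 2]], dtype=torch.int32)
    enc = one_hot_sketch(seqs, num_classes=3)
    print(enc.shape)  # torch.Size([2, 3, 3])
```

The cast to int64 is required by scatter_ and adds a copy, so this sketch is not expected to match the speed claimed for the adabmDCA function; it only illustrates the shape and dtype contract.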

This file was automatically generated via lazydocs.