
module functional


function one_hot

one_hot(x: Tensor, num_classes: int = -1, dtype: dtype = torch.float32)

A one-hot encoding function that is faster than PyTorch's built-in one (torch.nn.functional.one_hot). It accepts torch.int32 input and returns a float tensor (torch.float32 by default). It works only on 2D tensors.
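For context, the baseline below shows what the built-in PyTorch route looks like for the same kind of input. torch.nn.functional.one_hot expects int64 (torch.long) class indices and returns an int64 tensor, so an int32 input needs a cast before the call and another cast afterwards to get a float result. This only illustrates the baseline and does not measure speed. The example tensor and the choice of three classes are arbitrary.

```python
import torch
import torch.nn.functional as F

# A 2D tensor of class indices stored as int32.
x = torch.tensor([[0, 2, 1], [1, 0, 2]], dtype=torch.int32)

# Built-in route: cast the indices to int64, encode, then cast the int64 result to float.
baseline = F.one_hot(x.long(), num_classes=3).to(torch.float32)

print(baseline.shape)  # torch.Size([2, 3, 3])
print(baseline.dtype)  # torch.float32
```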

Args:

  • x (torch.Tensor): Input tensor to be one-hot encoded.
  • num_classes (int, optional): Number of classes. If -1, the number of classes is inferred from the input tensor. Defaults to -1.
  • dtype (torch.dtype, optional): Data type of the output tensor. Defaults to torch.float32.

Returns:

  • torch.Tensor: One-hot encoded tensor.
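The following is a minimal sketch of how a function with this signature could work. It is not this module's actual implementation. The name one_hot_sketch is hypothetical. The sketch makes several assumptions: the output shape is (N, L, num_classes), and num_classes=-1 infers the class count as the maximum index plus one, which matches PyTorch's convention. The indices are widened to int64 only because Tensor.scatter_ requires an int64 index tensor. Non-2D input is rejected, which mirrors the "2D tensors only" note above.

```python
import torch


def one_hot_sketch(x: torch.Tensor, num_classes: int = -1,
                   dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Hypothetical sketch: one-hot encode a 2D integer tensor into shape (N, L, C)."""
    if x.dim() != 2:
        raise ValueError(f"expected a 2D tensor, got {x.dim()} dimensions")
    if num_classes == -1:
        # Assumption: infer the number of classes as the maximum index plus one.
        num_classes = int(x.max().item()) + 1
    # Allocate the output directly in the requested dtype, so no cast is needed at the end.
    out = torch.zeros(*x.shape, num_classes, dtype=dtype, device=x.device)
    # scatter_ needs int64 indices; write 1 at each (row, col, class) position.
    out.scatter_(-1, x.long().unsqueeze(-1), 1)
    return out


x = torch.tensor([[0, 2], [1, 0]], dtype=torch.int32)
y = one_hot_sketch(x)
print(y.shape)  # torch.Size([2, 2, 3])
print(y[0, 1])  # tensor([0., 0., 1.])
```

Writing directly into a zero tensor of the target dtype avoids the int64 output and the float cast that the built-in route needs. Any actual speed difference would need to be measured.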

This file was automatically generated via lazydocs.