PyTorch One Hot Encoding

Ben Cook • Posted 2021-02-02 • Last updated 2021-12-13

PyTorch has a one_hot() function for converting class indices to one-hot encoded targets:

import torch
import torch.nn.functional as F

x = torch.tensor([4, 3, 2, 1, 0])
F.one_hot(x, num_classes=6)

# Expected result
# tensor([[0, 0, 0, 0, 1, 0],
#         [0, 0, 0, 1, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 1, 0, 0, 0, 0],
#         [1, 0, 0, 0, 0, 0]])

If you don’t pass the num_classes argument, one_hot() will infer the number of classes to be the largest class index plus one.
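For example, with a maximum class index of 2, one_hot() infers three classes:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([0, 1, 2])
y = F.one_hot(x)  # num_classes inferred as 3 (max index 2, plus one)

# Expected result
# tensor([[1, 0, 0],
#         [0, 1, 0],
#         [0, 0, 1]])
```

Note that inference only sees the indices actually present, so if a batch happens to be missing the highest class, the encoded width will be too small. Passing num_classes explicitly is safer.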

If you have more than one dimension in your class index tensor, one_hot() will encode labels along the last axis:

x = torch.tensor([
    [1, 2],
    [3, 4],
])
F.one_hot(x)

# Expected result
# tensor([[[0, 1, 0, 0, 0],
#          [0, 0, 1, 0, 0]],
#
#         [[0, 0, 0, 1, 0],
#          [0, 0, 0, 0, 1]]])
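One detail worth keeping in mind: one_hot() returns an integer (torch.int64) tensor. If a downstream loss or model expects floating point targets, cast the result:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1, 0, 2])
y = F.one_hot(x, num_classes=3)
# one_hot() produces int64 output
print(y.dtype)  # torch.int64

# Cast for losses that expect floating point inputs
y_float = y.float()
print(y_float.dtype)  # torch.float32
```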

If you want to reverse the operation, transforming a tensor from one-hot encoding back to class indices, call the .argmax() method over the last dimension:

x = torch.tensor([4, 3, 2, 1, 0])
y = F.one_hot(x, num_classes=6)
y.argmax(-1)

# Expected result
# tensor([4, 3, 2, 1, 0])

One-hot encoding is a good trick to be aware of in PyTorch, but it’s important to know that you don’t actually need it if you’re building a classifier with cross-entropy loss. In that case, just pass the class index targets to the loss function and PyTorch will take care of the rest.
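To illustrate, here is a minimal sketch (with made-up logits for a batch of 3 samples and 5 classes) showing that nn.CrossEntropyLoss accepts class indices directly, with no one-hot step required:

```python
import torch
import torch.nn as nn

# Hypothetical raw model outputs: batch of 3 samples, 5 classes
logits = torch.randn(3, 5)

# Targets are plain class indices, not one-hot vectors
targets = torch.tensor([1, 0, 4])

loss = nn.CrossEntropyLoss()(logits, targets)
print(loss)  # a scalar tensor
```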