PyTorch has a one_hot() function in torch.nn.functional for converting class indices to one-hot encoded targets:
import torch
import torch.nn.functional as F
x = torch.tensor([4, 3, 2, 1, 0])
F.one_hot(x, num_classes=6)
# Expected result
# tensor([[0, 0, 0, 0, 1, 0],
#         [0, 0, 0, 1, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 1, 0, 0, 0, 0],
#         [1, 0, 0, 0, 0, 0]])
If you don’t pass the num_classes argument in, one_hot() will infer the number of classes to be the largest class index plus one.
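For example, with the tensor above, the largest index is 4, so omitting num_classes yields five columns:

x = torch.tensor([4, 3, 2, 1, 0])
F.one_hot(x)
# Expected result
# tensor([[0, 0, 0, 0, 1],
#         [0, 0, 0, 1, 0],
#         [0, 0, 1, 0, 0],
#         [0, 1, 0, 0, 0],
#         [1, 0, 0, 0, 0]])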
If you have more than one dimension in your class index tensor, one_hot() will encode labels along the last axis:
x = torch.tensor([
    [1, 2],
    [3, 4],
])
F.one_hot(x)
# Expected result
# tensor([[[0, 1, 0, 0, 0],
#          [0, 0, 1, 0, 0]],
#
#         [[0, 0, 0, 1, 0],
#          [0, 0, 0, 0, 1]]])
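One detail worth noting: one_hot() always returns an integer tensor (torch.int64). If you need floating-point values, convert the result explicitly:

F.one_hot(x).dtype
# Expected result
# torch.int64
F.one_hot(x).float().dtype
# Expected result
# torch.float32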
If you want to reverse the operation, transforming a tensor from one-hot encoding back to class indices, use the .argmax() method along the last dimension:
x = torch.tensor([4, 3, 2, 1, 0])
y = F.one_hot(x, num_classes=6)
y.argmax(-1)
# Expected result
# tensor([4, 3, 2, 1, 0])
One-hot encoding is a good trick to be aware of in PyTorch, but it’s important to know that you don’t actually need it if you’re building a classifier with cross-entropy loss. In that case, just pass the class index targets into the loss function and PyTorch will take care of the rest.
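As a minimal sketch of that pattern (the logits and targets here are made up for illustration):

import torch
import torch.nn as nn

logits = torch.randn(5, 3)               # raw model outputs: 5 samples, 3 classes
targets = torch.tensor([0, 2, 1, 1, 0])  # plain class indices, no one-hot needed
loss = nn.CrossEntropyLoss()(logits, targets)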