The PyTorch Softmax Function

Ben Cook • Posted 2021-01-29 • Last updated 2021-10-14

The softmax activation function converts a vector of raw scores (logits) into a probability distribution over classes, which makes it a common final activation in multi-class classification models. The easiest way to use it in PyTorch is to call the top-level torch.softmax() function. Here’s an example:

import torch

x = torch.randn(2, 3, 4)
y = torch.softmax(x, dim=-1)

The dim argument is required: it specifies the axis along which to apply the softmax activation. Passing in dim=-1 applies softmax to the last dimension, so after the call, every slice along the last dimension sums to 1.
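For instance, if you wanted the probabilities normalized across the first dimension instead, you would pass dim=0. Here’s a quick sketch reusing the x from above:

y0 = torch.softmax(x, dim=0)

# Now y0.sum(dim=0) is all ones: the values along the first
# dimension sum to 1 instead of the values along the last one.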

You can verify the dim=-1 case by summing y over its last dimension with the sum() method:

y.sum(dim=-1)

# Expected result
# tensor([[1.0000, 1.0000, 1.0000],
#         [1.0000, 1.0000, 1.0000]])
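If you’re curious what softmax actually computes, you can reproduce the result by hand. The sketch below applies the textbook definition, exponentiate and normalize along the chosen dimension. PyTorch’s built-in implementation is more numerically stable than this naive formula, so treat it as an illustration only:

manual = torch.exp(x) / torch.exp(x).sum(dim=-1, keepdim=True)

torch.allclose(y, manual)

# Expected result
# True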

Another way you will see softmax used in the PyTorch docs is with the Softmax class:

import torch.nn as nn

softmax = nn.Softmax(dim=-1)
y = softmax(x)

As far as I can tell, the main advantage of this approach is that it lets you treat softmax as its own layer for the sake of code clarity, as in the sketch below. But generally, I prefer the top-level function for simplicity.
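If you do want softmax as a layer, here’s a minimal sketch of how it might look inside an nn.Sequential model. The toy classifier below (10 input features, 3 classes) is hypothetical and just illustrates the pattern:

import torch
import torch.nn as nn

# A toy classification head: 10 input features -> 3 class probabilities.
model = nn.Sequential(
    nn.Linear(10, 3),
    nn.Softmax(dim=-1),
)

probs = model(torch.randn(5, 10))

# probs has shape (5, 3) and each row sums to 1.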