There are three cases where you might want to use a cross entropy loss function:
- You have a single-label binary target
- You have a single-label categorical target
- You have a multi-label categorical target
You can use binary cross entropy for single-label binary targets and multi-label categorical targets (because it scores each element of a multi-label 0/1 indicator vector as an independent binary target, exactly as it scores a single-label binary target). You can use categorical cross entropy for single-label categorical targets.
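To make the multi-label point concrete, here is a minimal sketch (the tensor names and the toy 4-by-3 target are made up for illustration) checking that `BCELoss` over a whole multi-label target gives the same number as averaging separate single-label binary losses, one per label column:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Multi-label: 4 examples, 3 labels each; any subset of labels can be "on".
logits = torch.randn(4, 3)
targets = torch.tensor([[1., 0., 1.],
                        [0., 0., 0.],
                        [1., 1., 1.],
                        [0., 1., 0.]])
probs = torch.sigmoid(logits)

# BCE over the whole (4, 3) tensor at once...
multi_label_loss = nn.BCELoss()(probs, targets)

# ...matches the average of three single-label binary losses, one per column,
# because every element is scored as its own independent 0/1 prediction.
per_label = torch.stack(
    [nn.BCELoss()(probs[:, j], targets[:, j]) for j in range(3)]
)

print(multi_label_loss.item(), per_label.mean().item())
```

The two numbers agree here because every column has the same number of rows, so under the default `reduction="mean"` the mean of the per-column means equals the overall mean.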
But there are a few things that make it a little weird to figure out which PyTorch loss you should reach for in the above cases.
Why it’s confusing
- The naming conventions are different. The loss classes for binary and categorical cross entropy are `BCELoss` and `CrossEntropyLoss`, respectively. It’s not a huge deal, but Keras uses the same pattern for both (`BinaryCrossentropy` and `CategoricalCrossentropy`), which is a little nicer for tab completion.
- The shapes of the target tensors are different. For binary cross entropy, you pass in two tensors of the same shape. The output tensor should have elements in the range [0, 1], and the target tensor should hold dummy indicators: 0 for false and 1 for true (both the output and target tensors should be floats). For categorical cross entropy, the target is a one-dimensional tensor of class indices with type long, and the output should hold raw, unnormalized values with one column per class. That brings me to the third reason why cross entropy is confusing.
- The non-linear activation is automatically applied in `CrossEntropyLoss`. When you call `BCELoss`, you will typically want to apply the sigmoid activation function to the outputs first so that the values are in the range [0, 1]. For single-label categorical outputs, you also usually want the softmax activation function, but `CrossEntropyLoss` applies it for you (internally it computes a log-softmax followed by the negative log-likelihood), so you should not apply softmax yourself. Note: you can get the same behavior for binary cross entropy by using `BCEWithLogitsLoss`, which applies the sigmoid internally and is more numerically stable than a separate sigmoid followed by `BCELoss`.
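The point about activations being applied inside the loss is easy to check numerically. The sketch below (variable names and sizes are illustrative) compares each built-in loss against its manual equivalent and shows what happens if you apply softmax yourself before `CrossEntropyLoss`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Binary: BCEWithLogitsLoss applies the sigmoid itself,
# so it matches sigmoid + BCELoss on the same raw logits.
logits = torch.randn(8)
target = torch.randint(2, (8,), dtype=torch.float)  # float 0/1, same shape as logits
bce_manual = nn.BCELoss()(torch.sigmoid(logits), target)
bce_auto = nn.BCEWithLogitsLoss()(logits, target)

# Categorical: CrossEntropyLoss applies log-softmax itself,
# so it matches log_softmax + NLLLoss on the same raw scores.
scores = torch.randn(8, 4)           # (batch, classes), unnormalized
labels = torch.randint(4, (8,))      # class indices, dtype torch.long
ce_auto = nn.CrossEntropyLoss()(scores, labels)
ce_manual = nn.NLLLoss()(F.log_softmax(scores, dim=1), labels)

# Pitfall: softmax-ing the scores first means softmax is effectively applied twice,
# which squashes the inputs toward a uniform distribution and changes the loss.
ce_double = nn.CrossEntropyLoss()(F.softmax(scores, dim=1), labels)

print(bce_manual.item(), bce_auto.item())
print(ce_auto.item(), ce_manual.item(), ce_double.item())
```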
Example
Here’s a cheat sheet covering each kind of cross entropy loss described above:
```python
import torch
import torch.nn as nn

# Single-label binary: float 0/1 targets, same shape as the outputs
x = torch.randn(10)                               # raw logits
yhat = torch.sigmoid(x)                           # probabilities in [0, 1]
y = torch.randint(2, (10,), dtype=torch.float)
loss = nn.BCELoss()(yhat, y)

# Single-label binary with automatic sigmoid (pass the raw logits)
loss = nn.BCEWithLogitsLoss()(x, y)

# Single-label categorical: (batch, classes) logits, long class indices
x = torch.randn(10, 5)
y = torch.randint(5, (10,))                       # dtype is torch.long by default
loss = nn.CrossEntropyLoss()(x, y)

# Multi-label categorical: float 0/1 indicators, same shape as the outputs
x = torch.randn(10, 5)
yhat = torch.sigmoid(x)
y = torch.randint(2, (10, 5), dtype=torch.float)
loss = nn.BCELoss()(yhat, y)

# Multi-label categorical with automatic sigmoid (pass the raw logits)
loss = nn.BCEWithLogitsLoss()(x, y)
```
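One practical reason to prefer the with-logits variants in the cheat sheet is numerical stability. This minimal sketch uses a single made-up example with a very confident, wrong prediction (logit 200, true label 0); it relies on the documented behavior of `BCELoss` clamping its log terms to be at least -100:

```python
import torch
import torch.nn as nn

# A very confident (and wrong) prediction: logit 200, true label 0.
logit = torch.tensor([200.0])
label = torch.tensor([0.0])

# sigmoid(200) rounds to exactly 1.0 in float32, so log(1 - p) would be -inf;
# BCELoss clamps its log terms at -100, which caps this loss at 100.
plain = nn.BCELoss()(torch.sigmoid(logit), label)

# BCEWithLogitsLoss works on the logit directly in a stable form,
# so it returns the true loss of 200.
with_logits = nn.BCEWithLogitsLoss()(logit, label)

print(plain.item(), with_logits.item())  # 100.0 200.0
```

The clamped value still produces a finite loss, but it understates how wrong the prediction is, and the gradient through the saturated sigmoid is effectively zero, so training gets much less signal from examples like this one.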