Binary Cross Entropy Explained

Ben Cook • Posted 2021-02-22 • Last updated 2021-10-21

The most common loss function for training a binary classifier is binary cross entropy (sometimes called log loss). You can implement it in NumPy as a one-liner:

import numpy as np


def binary_cross_entropy(yhat: np.ndarray, y: np.ndarray) -> float:
    """Compute binary cross-entropy loss for a vector of predictions

    Parameters
    ----------
    yhat
        An array with len(yhat) predictions strictly between 0 and 1
    y
        An array with len(y) labels where each is one of {0, 1}
    """
    return -(y * np.log(yhat) + (1 - y) * np.log(1 - yhat)).mean()
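
For reference, the one-liner corresponds to the formula below. N (the number of records) and the subscript i (indexing a single record) are notation introduced here, not names from the function above. log is the natural logarithm, matching np.log.

```latex
\mathrm{BCE}(y, \hat{y}) = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log \hat{y}_i + (1 - y_i) \log\left(1 - \hat{y}_i\right) \right]
```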

Why does this work?

Good question! The motivation for this loss function comes from information theory. We’re trying to minimize the difference between the y and yhat distributions. That is, we want to minimize the difference between ground truth labels and model predictions. This is an elegant solution for training machine learning models, but the intuition is even simpler than that.

Binary classifiers, such as logistic regression, predict yes/no target variables that are typically encoded as 1 (for yes) or 0 (for no). When the model produces a floating point number between 0 and 1 (yhat in the function above), you can often interpret that as p(y == 1) or the probability that the true answer for that record is “yes”. The data you use to train the algorithm will have labels that are either 0 or 1 (y in the function above), since the answer for each record in your training data is known.

To train a good model, you want to penalize predictions that are far away from their ground truth values. That means you want to penalize values close to 0 when the label is 1 and you want to penalize values close to 1 when the label is 0.
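
To make the penalty concrete, here is a minimal sketch that evaluates the loss for a single record at a few predictions. The helper name record_loss is hypothetical and exists only for this illustration.

```python
import numpy as np


def record_loss(yhat: float, y: int) -> float:
    """Per-record binary cross-entropy (hypothetical helper for illustration)."""
    return float(-(y * np.log(yhat) + (1 - y) * np.log(1 - yhat)))


# Label is 1 ("yes"): the loss grows as the prediction drifts toward 0.
for yhat in (0.9, 0.5, 0.1):
    print(f"y=1, yhat={yhat}: loss={record_loss(yhat, 1):.3f}")

# Label is 0 ("no"): the same predictions are penalized in the reverse order.
for yhat in (0.9, 0.5, 0.1):
    print(f"y=0, yhat={yhat}: loss={record_loss(yhat, 0):.3f}")
```

For a record labeled 1, the losses come out to roughly 0.105, 0.693 and 2.303, so a confident wrong answer costs about twenty times more than a confident right one.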

The y and (1 - y) terms act like switches so that np.log(yhat) is added when the true answer is “yes” and np.log(1 - yhat) is added when the true answer is “no”. On its own, that sum would move in the opposite direction from what we want (since, for example, np.log(yhat) is larger when yhat is closer to 1 than 0), so we take the negative of the sum instead of the sum itself.
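
The switching behavior can be sketched by computing the two terms separately. The labels and predictions below are made-up values, and the np.where version is just an equivalent way to write the same selection explicitly.

```python
import numpy as np

# Made-up labels and predictions for illustration.
y = np.array([1, 0, 1, 0])
yhat = np.array([0.9, 0.2, 0.6, 0.7])

# Each log term is multiplied by a 0/1 switch, so only one survives per record.
yes_term = y * np.log(yhat)           # contributes only where y == 1
no_term = (1 - y) * np.log(1 - yhat)  # contributes only where y == 0

# Both log terms are negative, so negating the sum gives a positive loss.
per_record = -(yes_term + no_term)

# The same selection written explicitly, without the arithmetic switches.
explicit = -np.where(y == 1, np.log(yhat), np.log(1 - yhat))

print(per_record)
print(np.allclose(per_record, explicit))
```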

Here’s a plot with the first and second log terms (respectively) when they’re switched on:

[Figure: binary cross entropy terms]

Notice the log function increasingly penalizes values as they approach the wrong end of the range.

A couple of other things to watch out for:

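  • Since we’re taking np.log(yhat) and np.log(1 - yhat), we can’t use a model that predicts exactly 0 or 1 for yhat. This is because np.log(0) is -inf. For this reason, we typically apply the sigmoid activation function to raw model outputs. Mathematically, this allows values to get close to 0 or 1 but never actually reach the extremes of the range. In floating point, a large enough raw output can still round to exactly 0 or 1, so implementations often also clip yhat to a small margin away from the extremes.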
  • We typically divide by the number of records so the value is normalized and comparable across datasets with different sizes. This is the purpose of the .mean() method call in the implementation above.
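
One common guard against np.log(0) is to clip predictions slightly away from 0 and 1 before taking logs. The sketch below assumes a margin of eps=1e-12. Both the function name binary_cross_entropy_clipped and that default are illustrative choices, not something prescribed by the implementation above.

```python
import numpy as np


def binary_cross_entropy_clipped(yhat, y, eps=1e-12):
    """Binary cross-entropy with predictions clipped to [eps, 1 - eps].

    The name and the eps default are assumptions made for this sketch.
    """
    yhat = np.clip(yhat, eps, 1 - eps)
    return float(-(y * np.log(yhat) + (1 - y) * np.log(1 - yhat)).mean())


y = np.array([1, 0, 1])
# Predictions sitting exactly at the extremes; the third is confidently wrong.
yhat = np.array([1.0, 0.0, 0.0])

loss = binary_cross_entropy_clipped(yhat, y)
print(np.isfinite(loss))
```

Without the clip, these inputs would produce nan or inf, because the computation multiplies 0 by -inf or takes the log of 0 for the wrong prediction. With it, the confidently wrong record is still penalized heavily, but the result stays finite.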

In practice

Of course, you probably don’t need to implement binary cross entropy yourself. The loss function comes out of the box in PyTorch (for example, torch.nn.BCELoss) and TensorFlow (for example, tf.keras.losses.BinaryCrossentropy). When you use the loss function in these deep learning frameworks, you get automatic differentiation so you can easily learn weights that minimize the loss. The same quantity is also available in scikit-learn as the metric sklearn.metrics.log_loss.
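
To give a sense of what automatic differentiation saves you, here is a rough NumPy sketch that fits a one-feature logistic regression by gradient descent with a hand-derived gradient. The toy data, the learning rate and the step count are all made up for illustration. A framework would compute the gradient for you rather than requiring the derivation in the comment.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def binary_cross_entropy(yhat, y):
    return float(-(y * np.log(yhat) + (1 - y) * np.log(1 - yhat)).mean())


# Toy data (made up for illustration): the label is 1 when x > 0.
rng = np.random.default_rng(0)
x = rng.normal(size=200)
y = (x > 0).astype(float)

w, b = 0.0, 0.0  # weight and bias, both starting at zero
lr = 0.5         # learning rate, an arbitrary choice for this sketch
initial_loss = binary_cross_entropy(sigmoid(w * x + b), y)

for _ in range(100):
    yhat = sigmoid(w * x + b)
    # With a sigmoid output, the gradient of the mean loss with respect to
    # each raw output simplifies to (yhat - y) / n.
    grad_raw = (yhat - y) / len(y)
    w -= lr * np.sum(grad_raw * x)
    b -= lr * np.sum(grad_raw)

final_loss = binary_cross_entropy(sigmoid(w * x + b), y)
print(initial_loss, final_loss)
```

The loss starts at log(2), which is what you get from predicting 0.5 for every record, and decreases as the weight grows in the positive direction.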