Normalizing Images in PyTorch

Ben Cook • Posted 2021-01-15 • Last updated 2021-10-21

When you read an image into memory, its pixels are usually stored as 8-bit integers between 0 and 255 for each of the three channels. But regression models (including neural networks) prefer floating point values within a smaller range. Often, you want values to have a mean of 0 and a standard deviation of 1, like the standard normal distribution.
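As a quick illustration of that standardization, here is the operation on a made-up array of 8-bit pixel values (the numbers are arbitrary, chosen just for this sketch):

import numpy as np

# standardize a made-up array of 8-bit pixel values
pixels = np.array([0, 64, 128, 255], dtype=np.uint8)
x = pixels.astype(np.float32)
standardized = (x - x.mean()) / x.std()

standardized.mean(), standardized.std()  # approximately (0.0, 1.0)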

The Normalize() transform

Doing this transformation is called normalizing your images. In PyTorch, you can normalize your images with torchvision, a utility that provides convenient preprocessing transformations. For each value in an image, torchvision.transforms.Normalize() subtracts the channel mean and divides by the channel standard deviation.
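As a minimal sketch of that behavior (the mean and std values here are made up for the example), normalizing an all-ones tensor maps every value to (1.0 - 0.5) / 0.5 = 1.0:

import torch
import torchvision

t = torch.ones(3, 2, 2)  # a tiny 3-channel "image" of all ones
norm = torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
norm(t)  # every value becomes (1.0 - 0.5) / 0.5 = 1.0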

Let’s take a look at how this works. First, load an image into PIL[1]:

import io
import requests
from PIL import Image

resp = requests.get('https://jbencook.s3.amazonaws.com/assets/img/cat.jpg')
img_pil = Image.open(io.BytesIO(resp.content))

img_pil
[Image: the cat test image]

Now, take a look at the distribution of pixel values:

import matplotlib.pyplot as plt
import numpy as np

plt.hist(np.array(img_pil).ravel(), bins=50, density=True);
plt.xlabel("pixel values")
plt.ylabel("relative frequency")
plt.title("distribution of pixels");
[Image: histogram of the raw pixel distribution]

Notice that the values range from 0 to 255. Next, compose the ToTensor() and Normalize() transforms and apply them to the image:

import torchvision

transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
normalized_img = transform(img_pil)

Notice we’re passing in three values for the mean and three values for the standard deviation, one for each channel. The normalized_img result is a PyTorch tensor. Now, look at the distribution of pixel values for the normalized image:

plt.hist(normalized_img.numpy().ravel(), bins=30, density=True)
plt.xlabel("pixel values")
plt.ylabel("relative frequency")
plt.title("distribution of pixels");
[Image: histogram of the normalized pixel distribution]

The normalized values fall roughly in the range [-2.1, 2.6], the extremes of (x - mean) / std for x in [0, 1]. OK, so what's going on here? The way to normalize data is to subtract the mean and divide by the standard deviation. But the mean of the original pixel values is clearly nowhere near the values [0.485, 0.456, 0.406].

That’s because torchvision splits the work up between ToTensor() and Normalize():

  1. ToTensor() takes a PIL image (or np.uint8 NumPy array) with shape (n_rows, n_cols, n_channels) as input and returns a PyTorch tensor with floats between 0 and 1 and shape (n_channels, n_rows, n_cols).
  2. Normalize() subtracts the mean and divides by the standard deviation of the floating point values in the range [0, 1].

That means you need to know the mean and standard deviation of the floats, not the original pixels.
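You can check the first step on its own; applying just ToTensor() to the image gives a channels-first float tensor with values scaled into [0, 1]:

t = torchvision.transforms.ToTensor()(img_pil)

t.dtype, t.shape  # torch.float32 with shape (3, n_rows, n_cols)
t.min(), t.max()  # both values land in [0, 1]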

An alternative

If this seems confusing, you can code the transformation yourself:

import torch

MEAN = 255 * torch.tensor([0.485, 0.456, 0.406])
STD = 255 * torch.tensor([0.229, 0.224, 0.225])

x = torch.from_numpy(np.array(img_pil))
x = x.type(torch.float32)
x = x.permute(-1, 0, 1)  # (n_rows, n_cols, 3) -> (3, n_rows, n_cols)
x = (x - MEAN[:, None, None]) / STD[:, None, None]

You can use the same mean and standard deviation as before, but scaled to the original pixel range. To get the right tensor you need to:

  1. Convert the PIL image into a PyTorch tensor.
  2. Cast the uint8 values to float32.
  3. Rearrange the axes so that channels come first.
  4. Subtract the mean and divide by the standard deviation.

Note: you have to add dimensions to the mean and standard deviation for the broadcasting to work.
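To see what those extra None indices do, look at the shape they create: a (3,) tensor becomes (3, 1, 1), which broadcasts against the (3, n_rows, n_cols) image:

MEAN[:, None, None].shape

# Expected result
# torch.Size([3, 1, 1])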

You can prove that this recipe produces the same result as torchvision by computing the relative error between x and normalized_img from above:

torch.norm(x - normalized_img) / torch.norm(normalized_img)

# Expected result
# tensor(6.4971e-08)

One advantage of doing this yourself is that the transformation becomes more explicit. It also gives you more flexibility. If you want values between -1 and 1 instead of a mean of 0 and a standard deviation of 1, you can do that:

x = torch.from_numpy(np.array(img_pil))
x = x.type(torch.float32)
x = x.permute(-1, 0, 1)
x = 2 * x / 255 - 1

x.min(), x.max()

# Expected result
# (tensor(-1.), tensor(1.))

By the way, if you want to normalize all your images in a training or inference setup, you will probably want these transformations to live in the __getitem__() method of torch.utils.data.Dataset. That way they will be applied automatically to every image you access. This works fine with the torchvision transforms or with your own code.
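Here's a minimal sketch of that pattern. The ImageDataset class and its list of images are made up for illustration; transform is the torchvision pipeline from above:

class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, images, transform=None):
        self.images = images  # e.g. a list of PIL images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        if self.transform is not None:
            img = self.transform(img)  # normalization applied on access
        return img

dataset = ImageDataset([img_pil], transform=transform)
dataset[0].shape  # a normalized (3, n_rows, n_cols) tensor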

Reversing normalization

Matplotlib can display images with float values in [0, 1] or pixel values in [0, 255]. You need to move the channel dimension back to the end, but even then, the standard normal pixel values don't display very well:

plt.imshow(np.array(normalized_img).transpose(1, 2, 0))
plt.xticks([])
plt.yticks([]);

# Expected warning
# Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
[Image: the normalized image, clipped for display]

If you want to reverse the normalize transformation so that you can visualize images you’re loading into PyTorch, you’ll need to reverse the operations yourself. This involves multiplying by the standard deviation and adding the mean:

MEAN = torch.tensor([0.485, 0.456, 0.406])
STD = torch.tensor([0.229, 0.224, 0.225])

x = normalized_img * STD[:, None, None] + MEAN[:, None, None]
plt.imshow(x.numpy().transpose(1, 2, 0))
plt.xticks([])
plt.yticks([]);
[Image: the recovered cat image]

Voila! Now you can recover the RGB images that were previously normalized for PyTorch.

Notes

[1]: Check out the Jupyter notebook if you want to see all the import statements.