PyTorch Quick Start: Classifying an Image

Posted 2017-02-25 • Last updated 2021-10-21

In this post we’ll classify an image with PyTorch. If you prefer to skip the prose, you can check out the Jupyter notebook.

Two interesting features of PyTorch are Pythonic tensor manipulation that’s similar to NumPy, and dynamic computational graphs, which handle recurrent neural networks more naturally than static computational graphs do. A good description of the difference between dynamic and static graphs can be found here.
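
As a small illustration of the dynamic-graph idea (a sketch that is not part of the original notebook), the function below uses an ordinary Python loop whose length is chosen at call time, and autograd still tracks every operation. The name `repeated_double` is hypothetical.

```python
import torch

def repeated_double(x, steps):
    """Double x `steps` times using a plain Python loop."""
    for _ in range(steps):
        x = x * 2
    return x

x = torch.tensor(3.0, requires_grad=True)
y = repeated_double(x, steps=3)   # the graph is built as the loop runs
y.backward()                      # gradient of y with respect to x

result = y.item()
gradient = x.grad.item()
```

Changing `steps` changes the graph on the next call, with no separate compilation step.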

The most basic thing to do with a deep learning framework is to classify an image with a pre-trained model. This works out of the box with PyTorch.

Head over to pytorch.org for instructions on how to install PyTorch on your machine. Then, install other dependencies, including TorchVision:

pip install torchvision requests

Next, import packages and hardcode URLs.

import io
import requests
from PIL import Image
from torchvision import models
import torchvision.transforms as T
from torch.autograd import Variable

LABELS_URL = 'https://jbencook.s3.amazonaws.com/pytorch-quick-start/labels.json'
IMG_URL = 'https://jbencook.s3.amazonaws.com/pytorch-quick-start/cat.jpg'

The first two imports are for reading labels and an image from the internet. The Image class comes from a package called pillow and is the format for passing images into torchvision. LABELS_URL points to a JSON file that maps label indices to English descriptions of the ImageNet classes, and IMG_URL can be any image you like. If it shows one of the 1,000 ImageNet classes, this code should classify it correctly.

Now, initialize the model:

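squeeze = models.squeezenet1_1(pretrained=True)
squeeze.eval()  # inference mode: disables dropout so predictions are repeatable

This downloads the weights for the SqueezeNet model the first time it runs. Calling .eval() switches the model to inference mode. SqueezeNet’s classifier contains a dropout layer, which would otherwise randomly perturb the output. Note that torchvision 0.13 and later deprecate pretrained=True in favor of a weights argument.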

Define the pre-processing transform:


normalize = T.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)
preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    normalize
])

The specific set of steps in the image processing transform comes from the PyTorch examples repo here and here. Without these steps, the classifier will not work correctly.
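
To make the transform less opaque: Resize(256) scales the shorter side of the image to 256 pixels, CenterCrop(224) keeps the central 224×224 region, ToTensor() converts the image to a channels-first float tensor with values in [0, 1], and Normalize subtracts the per-channel mean and divides by the per-channel standard deviation. The last step can be sketched in plain Python for a single pixel. The helper name `normalize_pixel` is made up for illustration; the statistics are the ones used above.

```python
# Channel statistics copied from the T.Normalize call above.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel already scaled to [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]

# A pixel equal to the channel means maps to zero in every channel.
centered = normalize_pixel(MEAN)

# A mid-gray pixel ends up slightly above zero in every channel.
mid_gray = normalize_pixel([0.5, 0.5, 0.5])
```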

Download the image and create a pillow Image:

response = requests.get(IMG_URL)
img_pil = Image.open(io.BytesIO(response.content))

This is a quick trick for reading images from a URL. You can also read them from disk with Image.open("/path/to/image.jpg"). One cool thing about pillow images is that if you execute a code cell with the object in Jupyter, it will display the image for you.

>>> img_pil
[Image: cat image for inference]

Preprocess the image:

img_tensor = preprocess(img_pil)
img_tensor.unsqueeze_(0)

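Here, we apply the preprocessing transforms from above, then use .unsqueeze_(0) to add a leading dimension for the batch, since the model expects input of shape (batch, channels, height, width). In PyTorch, any tensor method whose name ends with an underscore operates in place.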
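
The difference between the two spellings can be checked on a dummy tensor. This is a minimal sketch that uses a tensor of zeros as a stand-in for a preprocessed image:

```python
import torch

chw = torch.zeros(3, 224, 224)           # stand-in for one preprocessed image

batched = chw.unsqueeze(0)               # returns a new tensor; chw keeps its shape
shape_after_copy = tuple(chw.shape)

chw.unsqueeze_(0)                        # trailing underscore: chw itself changes
shape_after_inplace = tuple(chw.shape)
```

After the first call `chw` is still three-dimensional; only the in-place version changes the tensor it is called on.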

Run a forward pass with the neural network:

img_variable = Variable(img_tensor)
fc_out = squeeze(img_variable)

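In older versions of PyTorch, the input to the network needs to be an autograd Variable. Since PyTorch 0.4, Variable has been merged into Tensor, so the wrapper is kept here only for compatibility and squeeze(img_tensor) works just as well. We run the forward pass by calling the squeeze model. Note that the output is raw class scores: the softmax activation function is not applied.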
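
If you want probabilities rather than raw scores, softmax can be applied afterwards (PyTorch provides torch.nn.functional.softmax for this). The plain-Python sketch below shows what the function computes, using three made-up scores rather than real model output:

```python
import math

def softmax(scores):
    """Convert raw scores to probabilities that sum to 1."""
    peak = max(scores)                            # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

raw_scores = [2.0, 1.0, 0.1]                      # hypothetical scores for three classes
probs = softmax(raw_scores)
```

Softmax preserves the ordering of the scores, so the index of the largest score is the same before and after. That is why the prediction later in the post works on the raw output.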

Download the labels:

labels = {int(key):value for (key, value)
          in requests.get(LABELS_URL).json().items()}

The requests package will parse JSON for us and return a dictionary. JSON object keys are always strings, though, so we convert them to integers to match the index of the maximum element in fc_out. After this step, labels will look like this:

>>> labels

{0: 'tench, Tinca tinca',
1: 'goldfish, Carassius auratus',
2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
3: 'tiger shark, Galeocerdo cuvieri',
4: 'hammerhead, hammerhead shark',
5: 'electric ray, crampfish, numbfish, torpedo',
6: 'stingray',
...
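
The key conversion can be seen in isolation with the standard json module. This sketch uses a two-entry excerpt of the labels shown above instead of downloading the file:

```python
import json

# Two entries in the same shape as the labels file.
raw = json.loads('{"0": "tench, Tinca tinca", "1": "goldfish, Carassius auratus"}')

# JSON keys arrive as strings, so an integer lookup on `raw` would fail.
labels_demo = {int(key): value for key, value in raw.items()}
```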

Now you can print the label:

>>> print(labels[fc_out.data.numpy().argmax()])

tabby, tabby cat
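
It is often useful to see the runners-up as well as the single best label. The following is a minimal sketch in plain Python, assuming the scores have been converted to a flat list (for example with fc_out[0].tolist()). The function name `top_k`, the shortened labels, and the scores are all made up for illustration.

```python
def top_k(scores, labels, k=5):
    """Return the k highest-scoring (label, score) pairs, best first."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(labels[i], scores[i]) for i in ranked[:k]]

demo_scores = [0.1, 2.5, 1.7]                     # hypothetical raw scores
demo_labels = {0: "tench", 1: "goldfish", 2: "great white shark"}
best_two = top_k(demo_scores, demo_labels, k=2)
```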

If you want to improve your knowledge of PyTorch, I recommend the PyTorch Pocket Reference. I get commissions for purchases made through this link. So you can learn more about PyTorch and support the blog at the same time!
