TorchVision Datasets: Getting Started

Ben Cook • Posted 2021-10-22 • Code

The TorchVision datasets subpackage is a convenient utility for accessing well-known public image and video datasets. You can use these tools to start training new computer vision models very quickly.

TorchVision Datasets Example

To get started, all you have to do is import one of the Dataset classes. Then, instantiate it and access one of the samples with indexing:

from torchvision import datasets

dataset = datasets.MNIST(root="./", download=True)
img, label = dataset[10]
img.size

# Expected result
# (28, 28)

You’ll get a tuple with a Pillow image and an integer label back:

[Image: the MNIST sample at index 10, a handwritten digit 3]

The TorchVision datasets implement __len__() and __getitem__() methods, which means that in addition to getting specific elements by index, you can also get the number of samples with the len() function:

len(dataset)

# Expected result
# 60000
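To see why those two methods are enough, here is a minimal sketch of the same protocol in plain Python. The class name ToyDataset and its contents are made up for illustration; it is not part of TorchVision.

```python
class ToyDataset:
    """Hypothetical dataset in plain Python (not a TorchVision class)."""

    def __init__(self, samples, labels):
        if len(samples) != len(labels):
            raise ValueError("samples and labels must have the same length")
        self.samples = samples
        self.labels = labels

    def __len__(self):
        # Called by len(dataset)
        return len(self.samples)

    def __getitem__(self, index):
        # Called by dataset[index]; returns a (sample, label) tuple
        return self.samples[index], self.labels[index]


toy = ToyDataset(samples=["a", "b", "c"], labels=[0, 1, 0])
print(len(toy))  # 3
print(toy[1])    # ('b', 1)
```

Any object that supports these two operations can be indexed and measured the same way the MNIST dataset is above.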

Additionally, a PyTorch DataLoader can wrap a TorchVision Dataset object to create batches automatically for training.

Most of the datasets return Pillow images, which the DataLoader's default batching cannot stack, so you need to pass in a transform that converts each image to a tensor:

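import torch
from torchvision import datasets, transforms

# ToTensor converts each Pillow image to a float tensor of shape (C, H, W)
dataset = datasets.MNIST(
    root="./",
    download=True,
    transform=transforms.ToTensor(),
)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=4)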

x, y = next(iter(data_loader))

x.shape

# Expected result
# torch.Size([4, 1, 28, 28])

API

The interface for the TorchVision Dataset classes is somewhat inconsistent because every dataset has a slightly different set of constraints. For example, many of the datasets return (PIL.Image, int) tuples, but this obviously wouldn’t work for videos (TorchVision packs them into tensors).

But generally speaking, the constructors take the following arguments:

  • root: where to download the raw dataset or where the Dataset class should expect to find a raw dataset that has already been downloaded.
  • split: which subset of the data to load. Depending on the dataset, this can be train, test, val, extra, and so on. Some datasets (MNIST, for example) take a boolean train argument instead, so it's best to look at the docs for the dataset you want to use.
  • download: a boolean indicating whether TorchVision should download the raw data for you. Setting it to True raises an error for datasets like ImageNet, which you have to obtain yourself. More on this below.
  • transform: a TorchVision transform to apply to the input image or video.

A word about ImageNet

ImageNet is no longer available for small companies or independent researchers. This is a real shame because pre-trained classifiers in model zoos are almost always trained on ImageNet.

However, it is possible to download most of the ImageNet dataset from Academic Torrents. I cannot endorse this strategy because I don’t know if it’s allowed.

If you did want to download the train and validation sets from ImageNet 2012, here are some steps you could follow:

1. Launch an Amazon Linux EC2 instance with at least 200GB of storage. The whole process takes about 2 hours on a c5.xlarge instance.

2. Install the aria2c command-line tool (see the aria2 documentation for installation instructions).

3. Download the tar files:

# Download the validation set
aria2c https://academictorrents.com/download/dfa9ab2528ce76b907047aa8cf8fc792852facb9.torrent

# Download the train set
aria2c https://academictorrents.com/download/a306397ccf9c2ead27155983c254227c0fd938e2.torrent

4. Make sure the files match the MD5 hashes (helpfully provided by the TorchVision team):

# Check the validation file
md5sum ILSVRC2012_img_val.tar

# Expected result
# 29b22e2961454d5413ddabcf34fc5622

# Check the train file
md5sum ILSVRC2012_img_train.tar

# Expected result
# 1d675b47d978889d74fa0da5fadfb00e

5. Upload the files to S3. Hosting the files costs a little over $3 per month.

6. Terminate the instance.

7. Send the Academic Torrents team some Bitcoin to say thank you.
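If md5sum isn't available, the checksum comparison in step 4 can also be done in Python. This is a minimal sketch using the standard library; the helper names md5_of_file and matches are made up for this example, and the file is read in chunks because the tar files are far too large to load into memory at once.

```python
import hashlib


def md5_of_file(path, chunk_size=1024 * 1024):
    """Return the hex MD5 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until f.read returns b""
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path, expected_md5):
    """Compare a file's MD5 digest against an expected hex string."""
    return md5_of_file(path) == expected_md5.strip().lower()
```

For example, matches("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622") should return True if the validation file downloaded correctly.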

Summary

And that’s all you need to know to get started with TorchVision Datasets. For production machine learning pipelines, you probably want to implement your own Dataset class, but the datasets that come out of the box with TorchVision are a great way to experiment quickly!
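As a rough idea of what implementing your own Dataset class looks like, here is a minimal sketch. The class name RandomImageDataset and its random contents are invented for this example; a real pipeline would load images from disk or another data source instead.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RandomImageDataset(Dataset):
    """Hypothetical dataset of random 1x28x28 tensors with integer labels."""

    def __init__(self, num_samples=8, transform=None):
        # Fixed seed so the generated data is reproducible
        generator = torch.Generator().manual_seed(0)
        self.images = torch.rand(num_samples, 1, 28, 28, generator=generator)
        self.labels = torch.randint(0, 10, (num_samples,), generator=generator)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, int(self.labels[index])


custom_dataset = RandomImageDataset(num_samples=8)
custom_loader = DataLoader(custom_dataset, batch_size=4)

x, y = next(iter(custom_loader))
print(x.shape)  # torch.Size([4, 1, 28, 28])
print(y.shape)  # torch.Size([4])
```

Because the class follows the same __len__() and __getitem__() interface, it works with a DataLoader exactly as the TorchVision datasets do.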