PyTorch DataLoader Quick Start

Posted 2021-10-07 • Last updated 2021-10-18

PyTorch comes with powerful data loading capabilities out of the box. But with great power comes great responsibility, and that makes data loading in PyTorch a fairly advanced topic.

One of the best ways to learn advanced topics is to start with the happy path. Then add complexity when you find out you need it. Let’s run through a quick start example.

What is a PyTorch DataLoader?

The PyTorch DataLoader class gives you an iterable over a Dataset. It’s useful because it can parallelize data loading and automatically shuffle and batch individual samples, all out of the box. This sets you up for a very simple training loop.
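To make "a very simple training loop" concrete, here's a minimal sketch. It uses PyTorch's built-in TensorDataset as a stand-in for a custom Dataset. The toy data, the linear model, and the hyperparameters are arbitrary assumptions chosen only for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy regression data: 100 samples with 4 features each.
torch.manual_seed(0)
features = torch.randn(100, 4)
targets = features.sum(dim=1, keepdim=True)  # target is just the feature sum

# The DataLoader handles shuffling and batching; the loop below only sees batches.
loader = DataLoader(TensorDataset(features, targets), batch_size=16, shuffle=True)

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

for epoch in range(5):
    for x, y in loader:  # each iteration yields one shuffled batch
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
```

With 100 samples and batch_size=16, each epoch yields 7 batches; the last one holds the 4 leftover samples.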

PyTorch Dataset

But to create a DataLoader, you have to start with a Dataset, the class responsible for actually reading samples into memory. When you’re implementing a DataLoader, the Dataset is where almost all of the interesting logic will go.

There are two styles of Dataset class: map-style and iterable-style. Map-style Datasets are more common and more straightforward, so we'll focus on them, but you can read more about iterable-style datasets in the docs.
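For contrast, here's a minimal sketch of an iterable-style dataset. CountingStream is a hypothetical name. The key difference is that it subclasses torch.utils.data.IterableDataset and implements __iter__() instead of __getitem__() and __len__(), so samples come out in order on demand rather than by index.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class CountingStream(IterableDataset):
    """Hypothetical iterable-style dataset that yields samples one by one."""

    def __init__(self, limit: int):
        self.limit = limit

    def __iter__(self):
        # No indexing: samples are produced sequentially, which suits
        # sources like log files or network streams.
        for i in range(self.limit):
            yield torch.tensor(float(i))


loader = DataLoader(CountingStream(5), batch_size=2)
batches = list(loader)  # three batches: [0., 1.], [2., 3.], [4.]
```

Two caveats worth knowing: DataLoader rejects shuffle=True for iterable-style datasets, and with num_workers > 0 each worker runs its own copy of __iter__(), so you would get duplicate samples unless you shard the stream (for example with torch.utils.data.get_worker_info()).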

To create a map-style Dataset class, you need to implement two methods: __getitem__() and __len__(). The __len__() method returns the total number of samples in the dataset and the __getitem__() method takes an index and returns the sample at that index.

PyTorch Dataset objects are very flexible: they can return whatever your model needs, such as tensors, NumPy arrays, or Python numbers, grouped in dicts or tuples. But supervised training datasets should usually return an input and a label. For illustration purposes, let's create a dataset where the input is a 3×3 matrix with the index along the diagonal. The label will be the index.

It should look like this:

dataset[3]

# Expected result
# {'x': array([[3., 0., 0.],
#         [0., 3., 0.],
#         [0., 0., 3.]]),
#  'y': 3}

Remember, all we have to implement are __getitem__() and __len__():

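from typing import Dict, Union

import numpy as np
import torch

class ToyDataset(torch.utils.data.Dataset):
    def __init__(self, size: int):
        # Number of samples this dataset reports and serves.
        self.size = size

    def __len__(self) -> int:
        # Total number of samples; the DataLoader's sampler uses it to generate indices.
        return self.size

    def __getitem__(self, index: int) -> Dict[str, Union[int, np.ndarray]]:
        # Input: 3x3 matrix with `index` on the diagonal. Label: the index itself.
        return dict(
            x=np.eye(3) * index,
            y=index,
        )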

Very simple. We can instantiate the class and start accessing individual samples:

dataset = ToyDataset(10)
dataset[3]

# Expected result
# {'x': array([[3., 0., 0.],
#         [0., 3., 0.],
#         [0., 0., 3.]]),
#  'y': 3}

If you happen to be working with image data, __getitem__() may be a good place to put your TorchVision transforms.
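Here's a sketch of that pattern, assuming you pass the transform in as a constructor argument. TransformedToyDataset is a hypothetical variant of ToyDataset, and to keep it self-contained a plain lambda stands in for something like a torchvision.transforms.Compose pipeline.

```python
from typing import Any, Callable, Dict, Optional

import numpy as np
import torch


class TransformedToyDataset(torch.utils.data.Dataset):
    """Hypothetical ToyDataset variant that applies an optional per-sample transform."""

    def __init__(
        self,
        size: int,
        transform: Optional[Callable[[np.ndarray], Any]] = None,
    ):
        self.size = size
        self.transform = transform  # e.g. a TorchVision pipeline for image data

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, index: int) -> Dict[str, Any]:
        x = np.eye(3) * index
        if self.transform is not None:
            x = self.transform(x)  # runs each time a sample is fetched
        return dict(x=x, y=index)


# Stand-in transform: scale inputs into [0, 1], assuming indices below 10.
scaled_dataset = TransformedToyDataset(10, transform=lambda x: x / 10.0)
```

Because the transform runs inside __getitem__(), it executes in the DataLoader's worker processes, so random augmentations produce a fresh result each time a sample is loaded.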

At this point, a sample is a dict with "x" as a NumPy array with shape (3, 3) and "y" as a Python integer. But what we want are batches of data: "x" should be a PyTorch tensor with shape (batch_size, 3, 3) and "y" should be a tensor with shape (batch_size,). This is where DataLoader comes back in.
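Under the hood, the DataLoader builds those batches with a collate function. Here's a small sketch calling the default one directly on a hand-made list of samples, assuming PyTorch 1.11 or later, where it is exposed as torch.utils.data.default_collate.

```python
import numpy as np
import torch
from torch.utils.data import default_collate  # exposed publicly in PyTorch 1.11+

# Three samples shaped like ToyDataset's output.
samples = [dict(x=np.eye(3) * i, y=i) for i in (2, 1, 3)]

# Dict keys are preserved; values are converted to tensors and stacked.
batch = default_collate(samples)
print(batch["x"].shape)  # torch.Size([3, 3, 3])
print(batch["y"])        # tensor([2, 1, 3])
```

Note that "x" keeps NumPy's float64 dtype after conversion, so you may want to cast it (for example with .float()) before feeding it to a model with float32 weights.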

PyTorch DataLoader

To iterate through batches of samples, pass your Dataset object to a DataLoader:

torch.manual_seed(1234)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=3,
    shuffle=True,
    num_workers=2,
)
for batch in loader:
    print(batch["x"].shape, batch["y"])

# Expected result
# torch.Size([3, 3, 3]) tensor([2, 1, 3])
# torch.Size([3, 3, 3]) tensor([6, 7, 9])
# torch.Size([3, 3, 3]) tensor([5, 4, 8])
# torch.Size([1, 3, 3]) tensor([0])

Notice a few things that are happening here:

  • The NumPy arrays and Python integers are both getting converted to PyTorch tensors.
  • Although we’re fetching individual samples in ToyDataset, the DataLoader is automatically batching them for us, with the batch size we request. This works even though the individual samples are in dict structures. This also works if you return tuples.
  • The samples are randomly shuffled. We maintain reproducibility by setting torch.manual_seed(1234).
  • The samples are read in parallel across worker processes. On platforms where Python starts those workers with the spawn method (the default on Windows and macOS), this code will fail if you run it in a Jupyter notebook. To get it to work, you need to put it in a Python script underneath an if __name__ == "__main__": check.
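Here's a sketch of what that script might look like. The main() function and its num_workers parameter are illustrative additions, not something PyTorch requires; the important parts are that ToyDataset lives at module top level and that the loading code sits behind the __main__ check.

```python
from typing import Dict, List, Union

import numpy as np
import torch


class ToyDataset(torch.utils.data.Dataset):
    # Defined at module top level so spawned worker processes can import it by name.
    def __init__(self, size: int):
        self.size = size

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, index: int) -> Dict[str, Union[int, np.ndarray]]:
        return dict(x=np.eye(3) * index, y=index)


def main(num_workers: int = 2) -> List[int]:
    torch.manual_seed(1234)
    loader = torch.utils.data.DataLoader(
        ToyDataset(10),
        batch_size=3,
        shuffle=True,
        num_workers=num_workers,
    )
    labels = []
    for batch in loader:
        print(batch["x"].shape, batch["y"])
        labels.extend(batch["y"].tolist())
    return labels


if __name__ == "__main__":
    # With spawn, each worker re-imports this module; the guard stops that
    # import from starting another DataLoader.
    main()
```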

There's one other thing I'm not doing in this sample that you should be aware of. If you need to use your tensors on a GPU (and you probably do for non-trivial PyTorch problems), then you should set pin_memory=True in the DataLoader. This speeds up host-to-GPU transfers by having the DataLoader place each batch in page-locked (pinned) memory. The PyTorch documentation covers it in more detail.
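Here's a sketch of how pin_memory is commonly paired with non_blocking=True when moving batches to the GPU. It uses a TensorDataset of random data as a stand-in and only enables pinning when CUDA is available, so it also runs on a CPU-only machine.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data = TensorDataset(torch.randn(64, 3, 3), torch.arange(64))
loader = DataLoader(
    data,
    batch_size=16,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),  # pinning only helps host-to-GPU copies
)

for x, y in loader:
    # non_blocking=True lets the copy from pinned memory run asynchronously.
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
```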

Summary

To review: the interesting part of custom PyTorch data loaders is the Dataset class you implement. From there, you get lots of nice features to simplify your data loop. If you need something more advanced, like custom batching logic, check out the API docs. Happy training!
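As one example of custom batching logic, here's a sketch of a collate_fn that pads variable-length samples instead of stacking them. VariableLengthDataset and pad_collate are hypothetical names; the collate_fn parameter of DataLoader and torch.nn.utils.rnn.pad_sequence are part of PyTorch.

```python
from typing import Dict, List

import torch
from torch.utils.data import DataLoader


class VariableLengthDataset(torch.utils.data.Dataset):
    """Hypothetical dataset whose samples are 1-D tensors of different lengths."""

    def __len__(self) -> int:
        return 4

    def __getitem__(self, index: int) -> torch.Tensor:
        return torch.ones(index + 1)  # lengths 1, 2, 3, 4


def pad_collate(samples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    # The default collate would fail here because the tensors can't be stacked,
    # so pad with zeros to the longest sample and keep the true lengths.
    lengths = torch.tensor([len(s) for s in samples])
    padded = torch.nn.utils.rnn.pad_sequence(samples, batch_first=True)
    return {"x": padded, "lengths": lengths}


loader = DataLoader(VariableLengthDataset(), batch_size=2, collate_fn=pad_collate)
batches = list(loader)
```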

If you want to improve your knowledge of PyTorch, I recommend the PyTorch Pocket Reference. I get commissions for purchases made through this link. So you can learn more about PyTorch and support the blog at the same time!

Contact

ben [at] sparrow [dot] dev
