
Intermediate Computer Vision: Episode 2

In this lesson, you will explore how to use the data utilities in PyTorch to efficiently load your data. You can find the corresponding Jupyter notebook here. If you already know PyTorch, you may want to skip ahead to episode 4, where we start modeling.

You will use the HaGRID dataset from the previous episode to create a custom Dataset and a corresponding DataLoader to feed the data to a model. The end result will be a custom class called GestureDataset that we can use to ensure reliable data pipelines in the remainder of this tutorial.


None of the patterns you will learn in this episode are unique to this example or to image data, so you will be able to adapt these lessons to work with any dataset you want to model with PyTorch.

1. Why use a Torch DataLoader?

PyTorch's built-in Dataset and DataLoader classes simplify the steps between ingesting data and feeding it to a model. They provide abstractions that address requirements common to most, if not all, deep learning scenarios.

  • The Dataset defines how the data is structured and how to fetch individual data instances.
  • The DataLoader uses the Dataset to load batches of data that can easily be shuffled, sampled, transformed, etc.
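The division of labor above can be seen with a toy example. This is a minimal sketch, assuming random stand-in tensors rather than real images. It uses PyTorch's built-in `TensorDataset` so that only the Dataset/DataLoader split is on display. The Dataset answers "give me item i", and the DataLoader turns those items into shuffled batches.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data standing in for images and labels: 10 samples, 3 features each.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10) % 2

# The Dataset knows how to fetch one instance by index...
dataset = TensorDataset(features, labels)
x0, y0 = dataset[0]

# ...and the DataLoader groups instances into (here, shuffled) batches.
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch_shapes = [(xb.shape[0], yb.shape[0]) for xb, yb in loader]
print(batch_shapes)  # [(4, 4), (4, 4), (2, 2)]
```

The last batch is smaller because 10 samples do not divide evenly into batches of 4.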

Importantly for many computer vision use cases, this PyTorch functionality is built to scale to training large networks on large datasets, and there are many optimization avenues for advanced users to explore.

2. What is a Torch DataLoader?

The DataLoader class helps you efficiently access batches from a dataset so you can feed them into your model. The DataLoader constructor has this signature:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,

You can read more detail here. The most important argument is dataset, which should be an instance of a torch.utils.data.Dataset. This object is what we will customize next. We can then use it to instantiate DataLoader objects that follow the standard pattern for feeding data into a PyTorch model.
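Several of the other constructor arguments change which batches you get. The sketch below uses a toy dataset of the integers 0 to 9, which is an assumption for illustration. It shows how `drop_last` affects the number of batches. It also shows how passing a seeded `torch.Generator` through the `generator` argument (present in recent PyTorch versions, though not in the abridged signature above) makes `shuffle=True` reproducible.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))

# Without drop_last, the final partial batch (10 % 4 = 2 items) is kept.
keep_last = DataLoader(dataset, batch_size=4, drop_last=False)
# With drop_last=True, that incomplete batch is discarded.
loader_drop = DataLoader(dataset, batch_size=4, drop_last=True)
print(len(keep_last), len(loader_drop))  # 3 2

def epoch_order(seed):
    # A seeded generator makes the shuffled order repeatable across runs.
    g = torch.Generator().manual_seed(seed)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=g)
    # Each batch is a one-element list holding a tensor of indices.
    return [xb.tolist() for (xb,) in loader]

print(epoch_order(0) == epoch_order(0))  # True
```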

3. Build a Torch Dataset

To create a DataLoader, we need to pass it a Dataset.
There are two ways to define a Torch Dataset: map-style and iterable-style. The difference is whether the class defines the __len__ and __getitem__ methods (map-style) or the __iter__ method (iterable-style). You can read more about this distinction here. For now, all you need to know is that in the rest of this episode you will build a custom dataset for the HaGRID data called GestureDataset.
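For contrast with the map-style GestureDataset, here is a minimal iterable-style sketch. `CountingStream` is a hypothetical class that yields samples in order from `__iter__` and defines no `__len__` or `__getitem__`, which suits data that arrives as a stream.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

class CountingStream(IterableDataset):
    """Hypothetical iterable-style dataset that streams the integers 0..n-1."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        # Samples are produced in order, on demand; there is no random access.
        for i in range(self.n):
            yield torch.tensor(i)

# Batching still works, but shuffle=True is rejected for iterable-style datasets.
loader = DataLoader(CountingStream(5), batch_size=2)
batches = [batch.tolist() for batch in loader]
print(batches)  # [[0, 1], [2, 3], [4]]
```

One design caveat is that with `num_workers > 0`, each worker runs its own copy of `__iter__`. An iterable-style dataset therefore has to split the stream itself, for example with `torch.utils.data.get_worker_info()`, or it will yield duplicates.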

4. Example: Components of the GestureDataset

In all remaining notebook examples and flows in this tutorial, we will use the GestureDataset. Much of the code is reused from the original source, which you can view here. The end goal is to create a GestureDataset object that we can easily use in model training code like the following snippet:

model = _initialize_model(model_name, checkpoint_path, device)
train_dataset = GestureDataset(is_train=True, conf=conf, transform=get_transform())
test_dataset = GestureDataset(is_train=False, conf=conf, transform=get_transform())
TrainClassifier.train(model, train_dataset, test_dataset, device)
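This tutorial does not show the internals of `TrainClassifier.train`. As a rough sketch of what a training function typically does with a dataset, consider the hypothetical helper `train_one_epoch` below. It wraps the dataset in a DataLoader and iterates over batches. A toy linear classifier and random data stand in for the real model and HaGRID images.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model, dataset, batch_size=8, lr=0.1):
    """Hypothetical helper: one optimization pass over `dataset`."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    total_loss, steps = 0.0, 0
    for inputs, targets in loader:  # each iteration yields one batch
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        steps += 1
    return total_loss / steps, steps

torch.manual_seed(0)
toy = TensorDataset(torch.randn(32, 4), torch.randint(0, 3, (32,)))
avg_loss, steps = train_one_epoch(nn.Linear(4, 3), toy)
print(steps)  # 4
```

The training code only needs something it can index and measure. That is why a custom Dataset slots in without changes to the loop.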

This section shows how to implement the methods needed to use GestureDataset, or any custom dataset, as in the code above. More than the details of this specific example, the main takeaway of this section is that a custom Dataset class must:

  1. Subclass torch.utils.data.Dataset.
  2. Define the constructor, __init__.
  3. Define either the __getitem__ and __len__ methods (map-style) or the __iter__ method (iterable-style). You can put whatever logic you need in these methods, as long as their signatures follow the PyTorch protocol.
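Here is a minimal sketch of those three requirements for a map-style dataset. `SquaresDataset` is a hypothetical toy class, not part of HaGRID, and its comments are numbered to match the list.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class SquaresDataset(Dataset):  # 1. subclass torch.utils.data.Dataset
    """Hypothetical map-style dataset returning (x, x**2) pairs."""

    def __init__(self, n):  # 2. constructor: store whatever state you need
        self.values = list(range(n))

    def __len__(self):  # 3a. number of samples, used by samplers
        return len(self.values)

    def __getitem__(self, index):  # 3b. fetch one sample by integer index
        x = self.values[index]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(x * x, dtype=torch.float32)

ds = SquaresDataset(6)
loader = DataLoader(ds, batch_size=3)
xb, yb = next(iter(loader))
print(yb.tolist())  # [0.0, 1.0, 4.0]
```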

4a. The Dataset Constructor

The constructor runs when you create a dataset instance. For GestureDataset, the constructor does the following:

  • Assigns instance attributes for the configuration, the transformations, and the label mappings.
  • Splits the images and their annotations into training and validation sets by user ID, so that all images of a given person end up in the same split.

class GestureDataset(torch.utils.data.Dataset):

    def __init__(self, is_train, conf, transform=None, is_test=False):
        self.conf = conf
        self.transform = transform
        self.is_train = is_train
        # Map each gesture name from the config to an integer class index.
        self.labels = {
            label: num for (label, num) in zip(self.conf.dataset.targets, range(len(self.conf.dataset.targets)))
        }
        self.leading_hand = {"right": 0, "left": 1}
        subset = self.conf.dataset.get("subset", None)
        # __read_annotations is a private helper defined elsewhere in the class.
        self.annotations = self.__read_annotations(subset)
        # Split by user ID: the first 80% of sorted users for training, the rest for validation.
        users = self.annotations["user_id"].unique()
        users = sorted(users)
        train_users = users[: int(len(users) * 0.8)]
        val_users = users[int(len(users) * 0.8) :]
        self.annotations = self.annotations.copy()
        if not is_test:
            if is_train:
                self.annotations = self.annotations[self.annotations["user_id"].isin(train_users)]
            else:
                self.annotations = self.annotations[self.annotations["user_id"].isin(val_users)]
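The two least obvious steps in the constructor are the label-to-index mapping and the split by user. Both can be checked in plain Python. The gesture names and user IDs below are hypothetical stand-ins for `conf.dataset.targets` and `annotations["user_id"]`.

```python
# Hypothetical stand-ins for conf.dataset.targets and annotations["user_id"].
targets = ["call", "dislike", "fist", "like"]
user_ids = ["u7", "u2", "u9", "u2", "u4", "u1", "u7", "u3", "u5", "u8", "u6"]

# Equivalent to the zip(..., range(len(...))) comprehension, written with enumerate.
labels = {label: num for num, label in enumerate(targets)}

# Split by *user*, not by image, so one person's images never span both splits.
users = sorted(set(user_ids))
cut = int(len(users) * 0.8)
train_users, val_users = users[:cut], users[cut:]

print(labels)       # {'call': 0, 'dislike': 1, 'fist': 2, 'like': 3}
print(train_users)  # ['u1', 'u2', 'u3', 'u4', 'u5', 'u6', 'u7']
print(val_users)    # ['u8', 'u9']
```

Splitting by user rather than by image keeps validation honest. If the same person appeared in both splits, the model could score well by recognizing hands or backgrounds instead of gestures.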


4b. Getting a Data Instance

__getitem__ is a special method that lets instances of the Dataset class be indexed like a list using []. In our case, we want this method to take an integer index and return a resized image together with its label, a dictionary holding the gesture class and the leading hand.

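class GestureDataset(torch.utils.data.Dataset):
    # (constructor and other methods omitted)

    def __getitem__(self, index: int):
        # Look up the annotation row for this index as a plain dict.
        row = self.annotations.iloc[[index]].to_dict("records")[0]
        # __prepare_image_target is a private helper defined elsewhere in the class.
        image_resized, gesture, leading_hand = self.__prepare_image_target(
            row["target"], row["name"], row["bboxes"], row["labels"], row["leading_hand"]
        )
        # Encode both targets as integers using the mappings built in __init__.
        label = {"gesture": self.labels[gesture], "leading_hand": self.leading_hand[leading_hand]}
        if self.transform is not None:
            image_resized, label = self.transform(image_resized, label)
        return image_resized, label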
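In `__getitem__`, the transform receives both the image and the label, so a transform can update the label when it changes the image. The sketch below mimics that structure with in-memory records. `ToyGestureDataset` and `HypotheticalFlip` are hypothetical names, and we do not claim HaGRID's `get_transform()` performs this flip.

```python
import torch
from torch.utils.data import Dataset

class HypotheticalFlip:
    """Hypothetical joint transform: mirror the image and swap the leading hand."""

    def __call__(self, image, label):
        image = torch.flip(image, dims=[2])  # dim 2 is width for a C x H x W tensor
        label = dict(label, leading_hand=1 - label["leading_hand"])  # right(0) <-> left(1)
        return image, label

class ToyGestureDataset(Dataset):
    """Mimics the shape of GestureDataset.__getitem__ using in-memory records."""

    def __init__(self, records, transform=None):
        self.records = records
        self.transform = transform
        self.labels = {"fist": 0, "palm": 1}
        self.leading_hand = {"right": 0, "left": 1}

    def __len__(self):
        return len(self.records)

    def __getitem__(self, index):
        row = self.records[index]
        label = {"gesture": self.labels[row["target"]],
                 "leading_hand": self.leading_hand[row["leading_hand"]]}
        image = row["image"]
        if self.transform is not None:
            image, label = self.transform(image, label)
        return image, label

image = torch.arange(6, dtype=torch.float32).reshape(1, 2, 3)
ds = ToyGestureDataset([{"image": image, "target": "palm", "leading_hand": "right"}],
                       transform=HypotheticalFlip())
flipped, label = ds[0]
print(label)  # {'gesture': 1, 'leading_hand': 1}
```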


5. Example: Using the GestureDataset

In this section, you will use the GestureDataset to instantiate a DataLoader and visualize one batch of images with their labels.

First, we will import dependencies.

import torch
from hagrid.classifier.dataset import GestureDataset
from hagrid.classifier.preprocess import get_transform
from hagrid.classifier.utils import collate_fn
from omegaconf import OmegaConf
from math import sqrt
import matplotlib.pyplot as plt

path_to_config = './hagrid/classifier/config/default.yaml'
conf = OmegaConf.load(path_to_config)

Then we instantiate the GestureDataset implemented here.

train_dataset = GestureDataset(is_train=True, conf=conf, transform=get_transform())

Now, you can use the train_dataset to create a data loader to request batches from.

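BATCH_SIZE = 16  # example value; choose what fits your memory and grid display

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=1,  # change this to load data faster; feasible values depend on your machine specs.
    collate_fn=collate_fn,
    shuffle=True,  # What happens to the image grid displayed by the view_batch function
                   # when you set shuffle=False in this constructor?
)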
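The `collate_fn` argument decides how a list of `(image, label)` samples becomes a batch. The sketch below compares PyTorch's `default_collate` with a zip-style collate. `zip_collate` is hypothetical, and we assume, without confirming against the hagrid source, that hagrid's `collate_fn` behaves like it. The samples are toy tensors.

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

def zip_collate(batch):
    # Hypothetical; assumed to resemble hagrid's collate_fn. Regroups
    # [(img, label), ...] into (imgs, labels) tuples without stacking.
    return tuple(zip(*batch))

samples = [(torch.zeros(3, 4, 4), {"gesture": g, "leading_hand": 0}) for g in range(3)]

# default_collate stacks images into one tensor and merges the dicts into a dict of tensors.
stacked, label_dict = default_collate(samples)
print(stacked.shape)  # torch.Size([3, 3, 4, 4])

# The zip-style collate keeps one dict per sample, which is what view_batch indexes.
loader = DataLoader(samples, batch_size=3, collate_fn=zip_collate)
images, labels = next(iter(loader))
print(len(images), labels[0])  # 3 {'gesture': 0, 'leading_hand': 0}
```

Keeping per-sample tuples also tolerates images of different sizes, which `default_collate` cannot stack.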

Here is a helper function to show the contents of a batch:

def view_batch(images, labels, batch_size):
    from math import ceil, sqrt
    # Lay the batch out on a roughly square grid, e.g. 4 x 4 for 16 images.
    n_cols = ceil(sqrt(batch_size))
    n_rows = ceil(batch_size / n_cols)
    # squeeze=False keeps `axes` two-dimensional even when there is a single row.
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    for ax in axes.flat:
        ax.set_axis_off()  # hide ticks and spines; unused cells stay blank
    for i, (image, label) in enumerate(zip(images, labels)):
        ax = axes[i // n_cols, i % n_cols]
        ax.imshow(image.permute(1, 2, 0))  # C x H x W tensor -> H x W x C for matplotlib
        ax.set_title(conf.dataset.targets[label["gesture"]], fontsize=10)
    fig.tight_layout()

Now we can take the next batch from the train_dataloader and view a grid of each image and its corresponding label.

images, labels = next(iter(train_dataloader))
view_batch(images, labels, BATCH_SIZE)

Nice! Getting a reliable data flow is a big step in any machine learning project, and this lesson has only scratched the surface of the tools PyTorch offers to help you do this. In this episode, you learned about PyTorch datasets and data loaders, and saw how to use them to efficiently and reliably load HaGRID dataset samples for training PyTorch models. Looking forward, you will pair PyTorch data loaders with Metaflow features to extend these concepts to working with datasets and models in the cloud. See you there!