Computer Vision with Metaflow: Intermediate Tutorial

In this tutorial, you will build a set of workflows to train and evaluate a machine learning model that performs image classification. You will use PyTorch and Metaflow to write computer vision code you can use as a foundation for real-world data science projects.

from metaflow import FlowSpec, Parameter, step, batch, environment, S3, metaflow_config, current

class TrainHandGestureClassifier(FlowSpec):

    S3_URI = Parameter(
        's3', type=str,
        help = 'The S3 URI to the root of the model objects.'
    )

    DATA_ROOT = Parameter(
        'data', type=str, default='data/',
        help = 'The relative location of the training data.'
    )

    IMAGES = Parameter(
        'images', type=str,
        default = '',
        help = 'The path to the images.'
    )

    ANNOTATIONS = Parameter(
        'annotations', type=str,
        default = ''
    )

    PATH_TO_CONFIG = Parameter(
        'config', type=str,
        default = 'hagrid/classifier/config/default.yaml',
        help = 'The path to classifier training config.'
    )

    NUMBER_OF_EPOCHS = Parameter(
        'epochs', type=int, default=100,
        help = 'The number of epochs to train the model for.'
    )

    MODEL_NAME = Parameter(
        'model', type=str,
        default = 'MobileNetV3_small',
        help = '''Pick a model from:
            - [ResNet18, ResNext50, ResNet152, MobileNetV3_small, MobileNetV3_large, Vitb32]'''
    )

    CHECKPOINT_PATH = Parameter(
        'checkpoint', type=str, default = None,
        help = 'Path to the model state you want to resume.'
    )

    # If you do not plan to checkpoint models in S3, then you may want
    # to use Metaflow's IncludeFile here, instead of this parameter to
    # the path. Make sure to import IncludeFile :)
    # CHECKPOINT_PATH = IncludeFile(
    #     'best_model.pth',
    #     is_text=False,
    #     help='The path to your local best_model.pth checkpoint',
    #     default='./best_model.pth'
    # )

    @step
    def start(self):
        # Configure the (remote) experiment tracking location.
        # In this tutorial, experiment tracking means
        #   1: Storing the best model state checkpoints to S3.
        #   2: Storing parameters as Metaflow artifacts.
        #   3: Storing metrics/logs with Tensorboard.
        import os
        print("Training {} in flow {}".format(self.MODEL_NAME, current.flow_name))
        self.datastore = metaflow_config.METAFLOW_CONFIG['METAFLOW_DATASTORE_SYSROOT_S3']
        self.experiment_storage_prefix = os.path.join(self.datastore, current.flow_name, current.run_id)
        self.next(self.train)

    def _download_data_from_s3(self, file, sample: bool = True):
        import zipfile
        import os
        with S3(s3root = self.S3_URI) as s3:
            if sample:
                path = os.path.join(self.DATA_ROOT, file)
                result = s3.get(path)
                with zipfile.ZipFile(result.path, 'r') as zip_ref:
                    # Assumption: the extraction call is missing from the source.
                    # Unpacking under DATA_ROOT matches the existence check in the train step.
                    zip_ref.extractall(self.DATA_ROOT)
            else: # Full dataset takes too long for the purpose of this tutorial.
                raise NotImplementedError()

    # 🚨🚨🚨 Do you want to ▶️ on ☁️☁️☁️?
    # You need to be configured with a Metaflow AWS deployment to use this decorator.
    # If you want to run locally, you can comment the `@batch` decorator out.
    # Assumption: the decorator's arguments are not shown in the source, so a bare
    # @batch is used here; set gpu, memory, and image to suit your deployment.
    @batch
    @step
    def train(self):
        # The module that provides run_train is elided in the source; restore
        # this import from your copy of the hagrid classifier package.
        # from <module> import run_train
        from hagrid.classifier.utils import get_device
        import os

        # Download the dataset onto the compute instance.
        if not os.path.exists(self.DATA_ROOT):
            print("Downloading images...")
            self._download_data_from_s3(self.IMAGES, sample=True)
            print("Downloading annotations...")
            self._download_data_from_s3(self.ANNOTATIONS, sample=True)

        # Train one of the available MODEL_NAME options, optionally resuming from a checkpoint.
        # Training will fail if CHECKPOINT_PATH doesn't match MODEL_NAME.
        # The user should know which checkpoint paths came from which models.
        self.train_args = dict(
            path_to_config = self.PATH_TO_CONFIG,
            number_of_epochs = self.NUMBER_OF_EPOCHS,
            device = get_device(),
            checkpoint_path = self.CHECKPOINT_PATH,
            model_name = self.MODEL_NAME,
            tensorboard_s3_prefix = self.experiment_storage_prefix,
            always_upload_best_model = True
        )
        _ = run_train(**self.train_args)

        # Move the best model checkpoint to S3 if METAFLOW_DATASTORE_SYSROOT_S3 is available.
        # See the comment in the start step about setting self.experiment_storage_prefix.
        experiment_path = os.path.join("experiments", self.MODEL_NAME)
        path_to_best_model = os.path.join(experiment_path, 'best_model.pth')
        self.best_model_location = os.path.join(self.experiment_storage_prefix, path_to_best_model)
        if self.best_model_location.startswith('s3://'):
            with S3(s3root = self.experiment_storage_prefix) as s3:
                s3.put_files([(path_to_best_model, path_to_best_model)])
                print("Best model checkpoint saved at {}".format(self.best_model_location))
        self.next(self.end)

    @step
    def end(self):
        pass # You could do some fancy analytics, post-processing, or write a nice message here too!

if __name__ == '__main__':
    TrainHandGestureClassifier()
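
The train step stores the best checkpoint under a prefix built from the datastore root, the flow name, and the run ID. The sketch below reproduces that path arithmetic on its own so the resulting S3 layout is easy to see. The helper name `best_model_location`, the bucket `s3://my-bucket/metaflow`, and the run ID `1234` are made up for illustration. It also uses `posixpath.join` in place of the flow's `os.path.join`, an assumption that keeps the separators as forward slashes on every platform.

```python
import posixpath


def best_model_location(datastore_root, flow_name, run_id, model_name):
    """Return (prefix, key, full_location) for the best checkpoint.

    Hypothetical helper that mirrors the path logic in the train step.
    """
    # Per-run prefix: <datastore root>/<flow name>/<run id>
    prefix = posixpath.join(datastore_root, flow_name, str(run_id))
    # Key relative to the prefix; the flow uses the same string as the
    # local file path when it calls S3.put_files.
    key = posixpath.join("experiments", model_name, "best_model.pth")
    return prefix, key, posixpath.join(prefix, key)


prefix, key, location = best_model_location(
    "s3://my-bucket/metaflow",  # hypothetical datastore root
    "TrainHandGestureClassifier",
    "1234",  # hypothetical run ID
    "MobileNetV3_small",
)
print(location)
# The flow uploads the checkpoint only when the location is an S3 URI.
print(location.startswith("s3://"))
```

Because the run ID is part of the prefix, each run writes its checkpoint to a separate location, so runs of the same model do not overwrite one another.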


We assume that you have taken the introductory tutorials or know the basics of Metaflow. This tutorial is aimed at learners who are familiar with model training concepts such as experiment tracking and model checkpointing. Familiarity with using Metaflow with AWS and PyTorch is helpful but not required. You will need access to a Metaflow deployment for the lessons that use cloud resources. Reach out in Slack to get set up, or get a feel for the full feature set Metaflow offers by signing up for a free hosted Sandbox. Alternatively, you can check out our beginner CV with Metaflow tutorial, which requires no cloud deployment.
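
As a refresher on the checkpointing idea assumed above, here is a minimal sketch of "keep only the best model state" using just the standard library. It is not how the hagrid training code or PyTorch saves checkpoints: the `BestCheckpoint` class, the JSON file format, and the accuracy values are all hypothetical stand-ins for a real model state and metric.

```python
import json
import os
import tempfile


class BestCheckpoint:
    """Keep only the best-scoring state on disk (hypothetical helper)."""

    def __init__(self, path):
        self.path = path
        self.best_score = float("-inf")

    def update(self, score, state):
        """Save `state` if `score` beats the best so far; return True if saved."""
        if score <= self.best_score:
            return False
        self.best_score = score
        # Write to a temporary file and then rename it, so an interrupted
        # write never leaves a half-written checkpoint behind.
        tmp_path = self.path + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump({"score": score, "state": state}, f)
        os.replace(tmp_path, self.path)
        return True

    def load(self):
        """Read the best checkpoint back from disk."""
        with open(self.path) as f:
            return json.load(f)


def demo():
    """Run three made-up epochs and report which ones were checkpointed."""
    with tempfile.TemporaryDirectory() as tmp:
        ckpt = BestCheckpoint(os.path.join(tmp, "best_model.json"))
        accuracies = [0.61, 0.58, 0.74]  # hypothetical validation accuracies
        saved = [ckpt.update(acc, {"epoch": epoch})
                 for epoch, acc in enumerate(accuracies)]
        return saved, ckpt.load()


saved, best = demo()
print(saved)
print(best)
```

Only epochs that improve on the best score so far are written, so the file on disk always holds the state to resume from or upload.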

Tutorial Structure

The tutorial is divided into episodes, each of which contains either a Metaflow script to run or a Jupyter notebook. Reading through the tutorial takes an estimated 30 minutes to an hour; running and adapting the code will add a few more hours.