
Use PyTorch with Metaflow

Question

How do I build and fit a PyTorch model in a Metaflow flow?

Solution

PyTorch supports many use cases and model types, so it helps to design your flow around which steps should run in the same compute environment. For example, if you want to run training on a GPU and validation on a CPU, you can split them into separate steps and assign each different resources with Metaflow decorators such as @batch and @kubernetes.
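
For instance, here is a minimal sketch of a flow that requests a GPU only for the training step; the decorator and resource values (@batch with gpu=1) are illustrative, and @kubernetes or @resources can be attached to a step in the same way:

from metaflow import FlowSpec, step, batch

class SplitComputeFlow(FlowSpec):

    @step
    def start(self):
        self.next(self.fit_model)

    # Illustrative: run this step on AWS Batch with one GPU.
    @batch(gpu=1, memory=16000)
    @step
    def fit_model(self):
        # ... train on the GPU ...
        self.next(self.evaluate_model)

    # No resource decorator: this step runs in the default (CPU) environment.
    @step
    def evaluate_model(self):
        # ... evaluate on the CPU ...
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    SplitComputeFlow()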

To structure the model training as a flow, import Metaflow's FlowSpec class and step decorator. The flow in this example uses Metaflow's conda integration for dependency management; here a single conda environment, defined with @conda_base, is shared by all steps.

The flow shows how to:

  • Read training hyperparameters with Metaflow's Parameter (these become command-line options, as shown in the example after this list).
  • Fetch the MNIST dataset and store it in PyTorch DataLoader objects to be used in downstream flow tasks.
  • Train the CNN and save the model as a flow artifact.
  • Evaluate the model in a separate step.
    • Separating training and evaluation steps is useful when you want to run different flow steps on different hardware.
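
Parameters declared with Parameter can be overridden as command-line options of the run command. For example (the values here are illustrative):

python fit_torch.py --environment=conda run --epochs 2 --lr 0.05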
fit_torch.py
from metaflow import (FlowSpec, step, Parameter,
                      conda_base)
from torch_utilities import (train, test, Net,
                             get_data_loaders)

@conda_base(libraries={"pytorch": "1.11.0",
                       "torchvision": "0.12.0"},
            python="3.8")
class TorchFlow(FlowSpec):

    lr = Parameter('lr', default=0.01)
    epochs = Parameter('epochs', default=1)

    @step
    def start(self):
        self.next(self.get_data)

    @step
    def get_data(self):
        import torch
        train_dataset, train_args = get_data_loaders()
        test_dataset, test_args = get_data_loaders("test")
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset, **train_args)
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset, **test_args)
        self.next(self.fit_model)

    @step
    def fit_model(self):
        import torch
        import torch.optim as optim
        from torch.optim.lr_scheduler import StepLR
        self.model = Net()
        optimizer = optim.Adadelta(
            self.model.parameters(), lr=self.lr)
        scheduler = StepLR(optimizer, step_size=1)
        for epoch in range(1, self.epochs + 1):
            train(self.model, self.train_loader,
                  optimizer, epoch)
            _ = test(self.model, self.test_loader)
            scheduler.step()
        self.next(self.evaluate_model)

    @step
    def evaluate_model(self):
        self.test_score = test(self.model,
                               self.test_loader)
        print(f"Model scored {100 * self.test_score}%")
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    TorchFlow()

This flow trains a CNN, adapted from the PyTorch examples repository, on the MNIST digit classification task. Below is the torch_utilities.py script imported by the flow; it contains the PyTorch definitions for the convolutional neural network, a training function, a testing function, and a data loading function.

torch_utilities.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(model, train_loader, optimizer, epoch):
    model.train()
    for idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if idx * len(data) % 10000 == 0:
            out = 'Train Epoch: ' + \
                '{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, idx * len(data), len(train_loader.dataset),
                    100. * idx / len(train_loader), loss.item())
            print(out)


def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += F.nll_loss(
                output, target,
                reduction='sum'
            ).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(
                target.view_as(pred)
            ).sum().item()
    return correct / len(test_loader.dataset)


def get_data_loaders(name="train"):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,),
                             (0.3081,))
    ])
    if name == "train":
        dataset = datasets.MNIST('../data',
                                 train=True,
                                 download=True,
                                 transform=transform)
        train_args = {'batch_size': 32}
        return dataset, train_args
    elif name == "test":
        dataset = datasets.MNIST('../data', train=False,
                                 transform=transform)
        test_args = {'batch_size': 32}
        return dataset, test_args

Run Flow

With the dependencies defined in torch_utilities.py and the flow defined in fit_torch.py, you can run the flow with the following command:

python fit_torch.py --environment=conda run
     Workflow starting (run-id 634):
[634/start/3294 (pid 3782)] Task is starting.
[634/start/3294 (pid 3782)] Task finished successfully.
[634/get_data/3295 (pid 4121)] Task is starting.
[634/get_data/3295 (pid 4121)] Downloading MNIST data... (progress output omitted)
[634/get_data/3295 (pid 4121)] Task finished successfully.
[634/fit_model/3296 (pid 4267)] Task is starting.
[634/fit_model/3296 (pid 4267)] Train Epoch: 1 [0/60000 (0%)] Loss: 2.299288
[634/fit_model/3296 (pid 4267)] Train Epoch: 1 [20000/60000 (33%)] Loss: 0.907390
[634/fit_model/3296 (pid 4267)] Train Epoch: 1 [40000/60000 (67%)] Loss: 0.384034
[634/fit_model/3296 (pid 4267)] Task finished successfully.
[634/evaluate_model/3297 (pid 4295)] Task is starting.
[634/evaluate_model/3297 (pid 4295)] Model scored 92.52%
[634/evaluate_model/3297 (pid 4295)] Task finished successfully.
[634/end/3298 (pid 4316)] Task is starting.
[634/end/3298 (pid 4316)] Task finished successfully.
Done!
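
Because the trained model and its score are stored as flow artifacts, you can retrieve them after the run with Metaflow's Client API. A minimal sketch (it assumes PyTorch is importable in the session so the model artifact can be unpickled):

from metaflow import Flow

# Grab the most recent successful run of the flow.
run = Flow('TorchFlow').latest_successful_run

# Artifacts assigned to self in the steps are available on run.data.
model = run.data.model          # the trained Net instance
print(run.data.test_score)      # accuracy computed in evaluate_model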

Further Reading