Cross-validation in Parallel


How can I use Metaflow to train a model for each cross-validation fold?


You can use Metaflow’s foreach pattern to do this. A nice side effect is that structuring the code this way makes it easy to run each parallel model training task on whatever resources you need.

This example uses the plain KFold cross-validation implementation in scikit-learn, but foreach works with any list, including the splits produced by the other cross-validation techniques scikit-learn offers for imbalanced classes, time series data, and more.
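
As a rough sketch of that point, the snippet below materializes two other scikit-learn splitters into the same kind of list that foreach iterates over. The helper name make_split_list and the toy data are assumptions made for illustration; they are not part of the flow in this example.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, TimeSeriesSplit

def make_split_list(splitter, x, y=None):
    """Hypothetical helper: turn a splitter into a list of (train_idx, valid_idx)."""
    return [(train_idx, valid_idx) for train_idx, valid_idx in splitter.split(x, y)]

# Toy data (illustrative only): 20 samples with imbalanced binary labels.
x = np.arange(40).reshape(20, 2)
y = np.array([0] * 15 + [1] * 5)

# Stratified folds keep the class ratio roughly constant in every fold.
stratified_splits = make_split_list(StratifiedKFold(n_splits=5), x, y)

# Time series folds always validate on samples that come after the training window.
time_series_splits = make_split_list(TimeSeriesSplit(n_splits=4), x)
```

Either list could be assigned to an artifact such as self.split and passed to foreach in the same way as the KFold splits.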

This flow shows how to:

  • Load data.
  • Generate KFold splits. See the make_splits step.
  • Fit and score a model for each fold.
  • Average the scores across folds and save the result for later analysis.
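
from metaflow import FlowSpec, step, Parameter

class KFoldFlow(FlowSpec):

    # Number of cross-validation folds.
    k = Parameter('k', default=5)

    @step
    def start(self):
        # Load the data and store it as artifacts for later steps.
        from sklearn.datasets import load_wine
        data = load_wine()
        self.x = data['data']
        self.y = data['target']
        self.next(self.make_splits)

    @step
    def make_splits(self):
        # Build a list of (train indices, validation indices) pairs.
        from sklearn.model_selection import KFold
        kfold = KFold(n_splits=self.k)
        self.split = []
        for train_id, valid_id in kfold.split(self.x):
            self.split.append((train_id, valid_id))
        # Fan out: one fit_and_score_model task per item in self.split.
        self.next(self.fit_and_score_model, foreach="split")

    @step
    def fit_and_score_model(self):
        from sklearn.tree import ExtraTreeClassifier
        from sklearn.metrics import accuracy_score
        self.model = ExtraTreeClassifier()
        # self.input holds this task's (train_id, valid_id) pair.
        train_x = self.x[self.input[0]]
        valid_x = self.x[self.input[1]]
        train_y = self.y[self.input[0]]
        valid_y = self.y[self.input[1]]
        self.model.fit(train_x, train_y)
        self.score = accuracy_score(
            valid_y, self.model.predict(valid_x))
        self.next(self.average_scores)

    @step
    def average_scores(self, models):
        # Join step: models gives access to each fold's artifacts.
        import numpy as np
        self.mean_score = np.mean([model.score
                                   for model in models])
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    KFoldFlow()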

Run the flow, replacing <flow_file>.py with the file that contains KFoldFlow:

python <flow_file>.py run

[1654221286227325/make_splits/2 (pid 71301)] Task is starting.
[1654221286227325/make_splits/2 (pid 71301)] Foreach yields 5 child steps.
[1654221286227325/make_splits/2 (pid 71301)] Task finished successfully.
[1654221286227325/fit_and_score_model/3 (pid 71327)] Task is starting.
[1654221286227325/fit_and_score_model/4 (pid 71328)] Task is starting.
[1654221286227325/fit_and_score_model/5 (pid 71330)] Task is starting.
[1654221286227325/fit_and_score_model/6 (pid 71331)] Task is starting.
[1654221286227325/fit_and_score_model/7 (pid 71332)] Task is starting.
[1654221286227325/fit_and_score_model/5 (pid 71330)] Task finished successfully.
[1654221286227325/fit_and_score_model/3 (pid 71327)] Task finished successfully.
[1654221286227325/fit_and_score_model/4 (pid 71328)] Task finished successfully.
[1654221286227325/fit_and_score_model/6 (pid 71331)] Task finished successfully.
[1654221286227325/fit_and_score_model/7 (pid 71332)] Task finished successfully.
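
Once a run has finished, the saved scores can be read back for later analysis with Metaflow's Client API. The following is a minimal sketch, assuming the flow above has completed at least one successful run in the current Metaflow environment; the helper name load_cv_results is hypothetical and used here for illustration only.

```python
def load_cv_results(flow_name="KFoldFlow"):
    """Hypothetical helper: return (mean_score, per_fold_scores) for the latest successful run."""
    # Imported lazily, in the same style as the imports inside the flow's steps.
    from metaflow import Flow

    run = Flow(flow_name).latest_successful_run
    if run is None:
        raise RuntimeError(f"No successful run found for {flow_name}")
    # Each foreach branch is a separate task of the fit_and_score_model step.
    fold_scores = [task.data.score for task in run["fit_and_score_model"]]
    # mean_score is set in average_scores and carries through to the end step.
    return run.data.mean_score, fold_scores
```

Calling load_cv_results() from a notebook or script should return the averaged score together with one score per fold, so the spread across folds can be inspected as well as the mean.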

