
Cross-validation in Parallel

Question

How can I use Metaflow to train a model for each cross-validation fold?

Solution

You can use Metaflow’s foreach pattern to do this. A useful side effect is that structuring the code this way makes it easy to run each parallel training task on whatever resources it needs.

This example uses scikit-learn's vanilla KFold cross-validation implementation, but foreach works with any list, including the splits produced by the other cross-validation techniques scikit-learn offers for imbalanced classes, time series data, and more.
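As a rough sketch of that point, the snippet below builds the same kind of list of (train, validation) index pairs with two other scikit-learn splitters, StratifiedKFold and TimeSeriesSplit. The helper name make_split_list and the toy arrays are hypothetical and only for illustration; inside a flow, the resulting list would be assigned to an artifact such as self.split and passed to foreach.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, TimeSeriesSplit


def make_split_list(splitter, x, y=None):
    """Hypothetical helper: materialize a splitter into a list of index pairs."""
    return [(train_idx, valid_idx) for train_idx, valid_idx in splitter.split(x, y)]


# Toy data for illustration only: 20 samples, two balanced classes.
x = np.arange(40).reshape(20, 2)
y = np.array([0] * 10 + [1] * 10)

# StratifiedKFold keeps the class proportions similar in every fold.
stratified = make_split_list(StratifiedKFold(n_splits=5), x, y)

# TimeSeriesSplit always validates on samples that come after the training ones.
time_ordered = make_split_list(TimeSeriesSplit(n_splits=4), x)
```

Either list has the same shape as the one built in make_splits, so the fit_and_score_model step would not need to change.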

This flow shows how to

  • Load data.
  • Generate KFold splits. See the make_splits step.
  • Fit and score a model for each fold.
  • Average the scores across folds and save the result for later analysis.
kfold_cv.py
from metaflow import FlowSpec, step, Parameter

class KFoldFlow(FlowSpec):

    # Number of cross-validation folds.
    k = Parameter('k', default=5)

    @step
    def start(self):
        from sklearn.datasets import load_wine
        data = load_wine()
        self.x = data['data']
        self.y = data['target']
        self.next(self.make_splits)

    @step
    def make_splits(self):
        from sklearn.model_selection import KFold
        kfold = KFold(n_splits=self.k)
        # Store one (train indices, validation indices) pair per fold.
        self.split = []
        for train_id, valid_id in kfold.split(self.x):
            self.split.append((train_id, valid_id))
        # Fan out: run fit_and_score_model once per element of self.split.
        self.next(self.fit_and_score_model,
                  foreach="split")

    @step
    def fit_and_score_model(self):
        from sklearn.tree import ExtraTreeClassifier
        from sklearn.metrics import accuracy_score
        self.model = ExtraTreeClassifier()
        # self.input holds this task's (train indices, validation indices) pair.
        train_x = self.x[self.input[0]]
        valid_x = self.x[self.input[1]]
        train_y = self.y[self.input[0]]
        valid_y = self.y[self.input[1]]
        self.model.fit(train_x, train_y)
        self.score = accuracy_score(
            valid_y,
            self.model.predict(valid_x)
        )
        self.next(self.average_scores)

    @step
    def average_scores(self, models):
        # Join step: models gives access to the artifacts of every fold's task.
        import numpy as np
        self.mean_score = np.mean([model.score
                                   for model in models])
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    KFoldFlow()
python kfold_cv.py run
    ...
[1654221286227325/make_splits/2 (pid 71301)] Task is starting.
[1654221286227325/make_splits/2 (pid 71301)] Foreach yields 5 child steps.
[1654221286227325/make_splits/2 (pid 71301)] Task finished successfully.
...
[1654221286227325/fit_and_score_model/3 (pid 71327)] Task is starting.
[1654221286227325/fit_and_score_model/4 (pid 71328)] Task is starting.
[1654221286227325/fit_and_score_model/5 (pid 71330)] Task is starting.
[1654221286227325/fit_and_score_model/6 (pid 71331)] Task is starting.
[1654221286227325/fit_and_score_model/7 (pid 71332)] Task is starting.
[1654221286227325/fit_and_score_model/5 (pid 71330)] Task finished successfully.
[1654221286227325/fit_and_score_model/3 (pid 71327)] Task finished successfully.
[1654221286227325/fit_and_score_model/4 (pid 71328)] Task finished successfully.
[1654221286227325/fit_and_score_model/6 (pid 71331)] Task finished successfully.
[1654221286227325/fit_and_score_model/7 (pid 71332)] Task finished successfully.
...
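Once a run has finished, the saved scores can be read back for later analysis. The following is a minimal sketch using Metaflow's Client API, not part of the original flow: the helper names summarize_fold_scores and latest_summary are hypothetical, and the code assumes the step and artifact names used in kfold_cv.py (fit_and_score_model, score).

```python
from statistics import mean


def summarize_fold_scores(run):
    """Collect the per-fold `score` artifacts from a finished KFoldFlow run."""
    # Indexing a run by step name and iterating yields one task per fold.
    scores = [task.data.score for task in run["fit_and_score_model"]]
    return {"scores": scores, "mean": mean(scores)}


def latest_summary():
    """Hypothetical convenience wrapper around the most recent successful run."""
    # Imported here so the helper above can be used without Metaflow installed.
    from metaflow import Flow

    run = Flow("KFoldFlow").latest_successful_run
    return summarize_fold_scores(run)
```

Calling latest_summary() from a notebook or script, after at least one successful run of the flow, should return the individual fold scores alongside their mean. The mean_score artifact computed in average_scores is also available directly as run.data.mean_score.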

Further Reading