Skip to main content

Grid Search with Metaflow


How can I do a grid search with Metaflow using scikit-learn's ParameterGrid?


This flow uses scikit-learn's ParameterGrid object to evaluate each grid point on a different worker using Metaflow's foreach pattern. This parallelizes workers across processes on a single node, or you can use Metaflow's @batch decorator to execute on different machines.
from metaflow import FlowSpec, step

class ParamGridFlow(FlowSpec):
    """Grid search over decision-tree hyperparameters, one grid point per task.

    Graph: start -> make_grid -> evaluate_model (foreach grid point)
    -> join -> end.

    Third-party imports are done inside each step rather than at module
    level, as in the original listing.
    """

    @step
    def start(self):
        """Load the iris dataset and store features/labels as artifacts."""
        from sklearn.datasets import load_iris

        data = load_iris()
        self.X, self.y = data['data'], data['target']
        self.next(self.make_grid)

    @step
    def make_grid(self):
        """Materialize the hyperparameter grid and fan out over it."""
        from sklearn.model_selection import ParameterGrid

        param_values = {
            'max_depth': [2, 4, 8, 16],
            'criterion': ['entropy', 'gini'],
        }
        # ParameterGrid yields one dict per combination (4 x 2 = 8 here);
        # it is turned into a list so it can be stored and iterated by foreach.
        self.grid_points = list(ParameterGrid(param_values))
        # Evaluate each point in the cross product of the ParameterGrid
        # in its own evaluate_model task.
        self.next(self.evaluate_model, foreach='grid_points')

    @step
    def evaluate_model(self):
        """Cross-validate one classifier configured by this task's grid point."""
        from sklearn.model_selection import cross_val_score
        from sklearn.tree import ExtraTreeClassifier

        # Inside a foreach branch, self.input is the element of grid_points
        # assigned to this task: a dict of keyword arguments.
        self.clf = ExtraTreeClassifier(**self.input)
        self.scores = cross_val_score(self.clf, self.X, self.y, cv=5)
        self.next(self.join)

    @step
    def join(self, inputs):
        """Collect the mean cross-validation score of every grid point.

        Args:
            inputs: The finished evaluate_model branches; each one exposes
                the ``scores`` artifact computed in that branch.
        """
        import numpy as np

        self.mean_scores = [np.mean(model.scores) for model in inputs]
        self.next(self.end)

    @step
    def end(self):
        """Terminate the flow; results stay available as ``mean_scores``."""
        pass

if __name__ == "__main__":
    # Instantiating the flow class hands control to Metaflow's command-line
    # interface, which handles sub-commands such as `run`.
    ParamGridFlow()
python <flow_file>.py run
[1654221281136541/make_grid/2 (pid 71177)] Task is starting.
[1654221281136541/make_grid/2 (pid 71177)] Foreach yields 8 child steps.
[1654221281136541/make_grid/2 (pid 71177)] Task finished successfully.
[1654221281136541/join/11 (pid 71261)] Task is starting.
[1654221281136541/join/11 (pid 71261)] Task finished successfully.
[1654221281136541/end/12 (pid 71284)] Task is starting.
[1654221281136541/end/12 (pid 71284)] Task finished successfully.

Further Reading