
Grid Search with Metaflow

Question

How can I do a grid search with Metaflow using scikit-learn's ParameterGrid?

Solution

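This flow uses scikit-learn's ParameterGrid to enumerate every combination of hyperparameter values. It then uses Metaflow's foreach pattern to evaluate each grid point in its own task. When you run the flow locally, these tasks execute in parallel as separate processes on a single machine. To distribute them across multiple machines, add Metaflow's @batch decorator to the evaluate_model step so that each task runs on AWS Batch.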
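To see what the foreach fans out over, the sketch below enumerates the same grid using only the standard library. It assumes ParameterGrid's behavior for a single dict of lists (sorted keys, Cartesian product of the value lists). `expand_grid` is a hypothetical helper for illustration and is not part of scikit-learn or Metaflow.

```python
from itertools import product


def expand_grid(param_values):
    """Hypothetical stand-in for ParameterGrid on a single dict of lists:
    sort the keys, then take the Cartesian product of their value lists
    (the last key varies fastest)."""
    keys = sorted(param_values)
    return [dict(zip(keys, combo))
            for combo in product(*(param_values[k] for k in keys))]


param_values = {'max_depth': [2, 4, 8, 16],
                'criterion': ['entropy', 'gini']}
grid_points = expand_grid(param_values)

# With foreach='grid_points', Metaflow starts one evaluate_model task per
# list element; inside task i, self.input is grid_points[i].
for i, point in enumerate(grid_points):
    print(f"evaluate_model task {i}: self.input = {point}")
```

The 4 × 2 combinations explain why the run log reports 8 child steps for make_grid.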

foreach_param_grid.py
from metaflow import FlowSpec, step


class ParamGridFlow(FlowSpec):

    @step
    def start(self):
        # Load the iris dataset; X and y become artifacts available to later steps.
        from sklearn.datasets import load_iris
        data = load_iris()
        self.X, self.y = data['data'], data['target']
        self.next(self.make_grid)

    @step
    def make_grid(self):
        # Expand the hyperparameter dict into a list of dicts,
        # one per combination (4 depths x 2 criteria = 8 grid points).
        from sklearn.model_selection import ParameterGrid
        param_values = {'max_depth': [2, 4, 8, 16],
                        'criterion': ['entropy', 'gini']}
        self.grid_points = list(ParameterGrid(param_values))
        # Fan out: launch one evaluate_model task per grid point.
        self.next(self.evaluate_model, foreach='grid_points')

    @step
    def evaluate_model(self):
        # self.input holds this branch's grid point,
        # e.g. {'criterion': 'entropy', 'max_depth': 2}.
        from sklearn.tree import ExtraTreeClassifier
        from sklearn.model_selection import cross_val_score
        self.clf = ExtraTreeClassifier(**self.input)
        # 5-fold cross-validated accuracy for this grid point.
        self.scores = cross_val_score(self.clf, self.X, self.y, cv=5)
        self.next(self.join)

    @step
    def join(self, inputs):
        # Fan in: collect the mean CV score from every branch.
        import numpy as np
        self.mean_scores = [np.mean(model.scores) for model in inputs]
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    ParamGridFlow()

Run the flow:

python foreach_param_grid.py run
...
[1654221281136541/make_grid/2 (pid 71177)] Task is starting.
[1654221281136541/make_grid/2 (pid 71177)] Foreach yields 8 child steps.
[1654221281136541/make_grid/2 (pid 71177)] Task finished successfully.
...
[1654221281136541/join/11 (pid 71261)] Task is starting.
[1654221281136541/join/11 (pid 71261)] Task finished successfully.
...
[1654221281136541/end/12 (pid 71284)] Task is starting.
[1654221281136541/end/12 (pid 71284)] Task finished successfully.
...
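The join step stores only a list of mean scores, so it does not record which grid point produced each score. One way to recover the best hyperparameters after a run is to read each evaluate_model task's clf and scores artifacts with the Metaflow Client API. This is a minimal sketch that assumes the flow has completed successfully and that you run it from the same directory with the default local metadata provider. The helpers `best_grid_point` and `collect_results` and the `GRID_KEYS` tuple are hypothetical names for illustration.

```python
import statistics

# Hypothetical: the hyperparameter names searched in make_grid.
GRID_KEYS = ("criterion", "max_depth")


def best_grid_point(results):
    """Return (params, mean_score) for the highest mean CV score.

    results: iterable of (params_dict, fold_scores) pairs.
    """
    scored = [(params, statistics.mean(scores)) for params, scores in results]
    if not scored:
        raise ValueError("no results to rank")
    return max(scored, key=lambda pair: pair[1])


def collect_results(flow_name="ParamGridFlow"):
    """Read (params, fold scores) from every evaluate_model task of the
    latest successful run, using the Metaflow Client API."""
    # Imported lazily so best_grid_point can be used without Metaflow installed.
    from metaflow import Flow
    run = Flow(flow_name).latest_successful_run
    if run is None:
        raise RuntimeError(f"no successful run of {flow_name} found")
    results = []
    for task in run["evaluate_model"]:
        # The stored (unfitted) classifier remembers the parameters it was built with.
        params = task.data.clf.get_params()
        results.append(({k: params[k] for k in GRID_KEYS}, list(task.data.scores)))
    return results
```

After a run, `best_grid_point(collect_results())` returns the winning parameter dict and its mean 5-fold accuracy. Reading the parameters back from each task's clf artifact avoids relying on the order in which the join step receives its inputs.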


For help or feedback, please join Metaflow Slack. To suggest an article, you may open an issue on GitHub.
