Use Scikit-learn Estimators with Metaflow

Question

I have a scikit-learn workflow that I want to incorporate into a Metaflow flow. How can I include model fitting, prediction, feature transformations, and other capabilities enabled by scikit-learn in flow steps?

Solution

Note that this example uses a random forest classifier, but the same approach applies to all scikit-learn models.

To turn a scikit-learn workflow into a Metaflow flow, first decide what the steps are going to be. In this case, there are distinct steps to:

  1. Load data.
  2. Instantiate a model.
  3. Train a model with cross-validation.
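As a point of reference, a plain scikit-learn script covering these three steps might look like the minimal sketch below. It assumes the iris dataset and the same random forest hyperparameters used in the flow in this guide; the variable names are illustrative.

```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

# 1. Load data.
iris = datasets.load_iris()
X, y = iris["data"], iris["target"]

# 2. Instantiate a model.
clf = RandomForestClassifier(
    n_estimators=10,
    max_depth=None,
    min_samples_split=2,
    random_state=0,
)

# 3. Train a model with 5-fold cross-validation; one accuracy score per fold.
scores = cross_val_score(clf, X, y, cv=5)
print("Mean cross-validation accuracy:", scores.mean())
```

Each numbered comment corresponds to one candidate step in the flow.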

1. Estimators to Flows

Deciding how to split a workflow into steps involves some design choices, for which we have some rules of thumb. One benefit of separating a flow into Metaflow steps is that you can resume a failed run from the step that failed, without recomputing the steps before it, which makes development much faster.

2. Run Flow

This flow shows how to:

  • Import FlowSpec and step.
  • Include step-specific imports within each step.
  • Assign any data structures you wish to pass between steps to self.
  • Train a model and apply cross validation to evaluate it.
fit_sklearn_estimator.py
from metaflow import FlowSpec, step

class SklearnFlow(FlowSpec):

    @step
    def start(self):
        # Step-specific import: only this step needs the datasets module.
        from sklearn import datasets
        self.iris = datasets.load_iris()
        # Anything assigned to self is stored and passed to later steps.
        self.X = self.iris['data']
        self.y = self.iris['target']
        self.next(self.rf_model)

    @step
    def rf_model(self):
        from sklearn.ensemble import RandomForestClassifier
        self.clf = RandomForestClassifier(
            n_estimators=10,
            max_depth=None,
            min_samples_split=2,
            random_state=0
        )
        self.next(self.train)

    @step
    def train(self):
        from sklearn.model_selection import cross_val_score
        # Evaluate the classifier with 5-fold cross-validation.
        self.scores = cross_val_score(self.clf, self.X,
                                      self.y, cv=5)
        self.next(self.end)

    @step
    def end(self):
        print("SklearnFlow is all done.")


if __name__ == "__main__":
    SklearnFlow()

Run the flow with the --with card CLI option, which attaches a Metaflow card to the steps of the run. Cards produce HTML visualizations.

python fit_sklearn_estimator.py run --with card
...
[1663366789156643/end/4 (pid 5065)] Task is starting.
[1663366789156643/end/4 (pid 5065)] SklearnFlow is all done.
[1663366789156643/end/4 (pid 5065)] Task finished successfully.
...
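The flow above evaluates the classifier but does not produce predictions: cross_val_score fits clones of the estimator, so self.clf itself is still unfitted after the train step. If you also want fitting and prediction, the body of an additional step could resemble the sketch below. The function name fit_and_predict, the held-out split, and its parameters are illustrative assumptions, not part of the original flow.

```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split


def fit_and_predict(clf, X, y, test_size=0.2, seed=0):
    """Fit clf on a training split and predict on the held-out split.

    Hypothetical helper for illustration; returns (predictions, accuracy).
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=seed, stratify=y
    )
    clf.fit(X_train, y_train)        # fits the estimator in place
    preds = clf.predict(X_test)      # class labels for the held-out rows
    accuracy = float((preds == y_test).mean())
    return preds, accuracy


iris = datasets.load_iris()
clf = RandomForestClassifier(n_estimators=10, random_state=0)
preds, accuracy = fit_and_predict(clf, iris["data"], iris["target"])
print("Held-out accuracy:", accuracy)
```

Inside a flow, the same calls would go in a step placed after train, with the fitted model and the predictions assigned to self so that later steps can use them.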

3. View Card

Now you can view the card for the train step using this command:

python fit_sklearn_estimator.py card view train
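Besides the card, the artifacts assigned to self can be read back in Python through Metaflow's Client API. The sketch below assumes the flow has already completed at least one successful run on this machine; the helper names load_latest_scores and summarize_scores are hypothetical and chosen for illustration.

```python
from statistics import mean, stdev


def summarize_scores(scores):
    """Return (mean, sample standard deviation) of cross-validation scores."""
    values = [float(s) for s in scores]
    return mean(values), stdev(values)


def load_latest_scores(flow_name="SklearnFlow"):
    """Fetch the scores artifact from the latest successful run of the flow."""
    # Imported lazily, mirroring the step-specific imports used in the flow.
    from metaflow import Flow

    run = Flow(flow_name).latest_successful_run
    if run is None:
        raise RuntimeError(f"No successful run found for {flow_name}")
    # The scores artifact was assigned to self in the train step.
    return run["train"].task.data.scores


# Example usage after a successful run:
#   avg, spread = summarize_scores(load_latest_scores())
#   print(f"accuracy: {avg:.3f} +/- {spread:.3f}")
```

Keeping the summary logic separate from the Metaflow lookup makes it usable on any sequence of scores, for example in a notebook.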

Further Reading