
Gradient Boosted Trees Flow

In this episode, you will reuse the general flow structure from Episode 1. Specifically, you will replace the random forest model in your flow with an XGBoost model.

1. Write a Gradient Boosted Trees Flow

The flow has the following structure:

  • Parameter values are defined at the beginning of the class.
  • The start step loads the Iris dataset and separates its features (X) and labels (y) for use in downstream steps.
  • The train_xgb step fits and evaluates an xgboost.XGBClassifier on the classification task using k-fold cross-validation.
  • The end step prints the mean and standard deviation of the cross-validated accuracy scores.
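from metaflow import FlowSpec, step, Parameter

class GradientBoostedTreesFlow(FlowSpec):

    # Parameters can be overridden on the command line, e.g. --seed 42.
    random_state = Parameter("seed", default=12)
    n_estimators = Parameter("n-est", default=10)
    eval_metric = Parameter("eval-metric", default='mlogloss')
    k_fold = Parameter("k", default=5)

    @step
    def start(self):
        # Load the Iris dataset and separate features from labels.
        from sklearn import datasets
        self.iris = datasets.load_iris()
        self.X = self.iris['data']
        self.y = self.iris['target']
        self.next(self.train_xgb)

    @step
    def train_xgb(self):
        # Score an XGBoost classifier with k-fold cross-validation,
        # configured from the flow's Parameters.
        from xgboost import XGBClassifier
        from sklearn.model_selection import cross_val_score
        self.clf = XGBClassifier(
            n_estimators=self.n_estimators,
            random_state=self.random_state,
            eval_metric=self.eval_metric)
        self.scores = cross_val_score(
            self.clf, self.X, self.y, cv=self.k_fold)
        self.next(self.end)

    @step
    def end(self):
        # Report the mean and standard deviation of the fold accuracies.
        import numpy as np
        msg = "Gradient Boosted Trees Model Accuracy: {} \u00B1 {}%"
        self.mean = round(100*np.mean(self.scores), 3)
        self.std = round(100*np.std(self.scores), 3)
        print(msg.format(self.mean, self.std))

if __name__ == "__main__":
    GradientBoostedTreesFlow()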
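When cv is an integer and the estimator is a classifier, scikit-learn's cross_val_score splits the data with stratified k-fold, so each test fold keeps roughly the class proportions of the full dataset. This matters for Iris, whose target array is sorted by class. The sketch below is a simplified, standard-library illustration of stratified k-fold that deals each class's indices to the folds round-robin. It is not scikit-learn's exact StratifiedKFold algorithm, and stratified_k_fold is a hypothetical helper.

```python
from collections import Counter, defaultdict


def stratified_k_fold(labels, k):
    """Yield (train_indices, test_indices) for k folds whose test sets keep
    roughly the class proportions of `labels` (simplified sketch)."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Deal indices to folds round-robin, class by class, so every fold
    # receives about the same number of samples from each class.
    folds = [[] for _ in range(k)]
    counter = 0
    for indices in by_class.values():
        for idx in indices:
            folds[counter % k].append(idx)
            counter += 1

    # Each fold serves once as the test set; the rest form the training set.
    for i in range(k):
        test = sorted(folds[i])
        test_set = set(test)
        train = [idx for idx in range(len(labels)) if idx not in test_set]
        yield train, test


# Iris-like labels: 150 samples sorted by class, 50 per class.
labels = [0] * 50 + [1] * 50 + [2] * 50
splits = list(stratified_k_fold(labels, k=5))
for train, test in splits:
    print(len(train), len(test), Counter(labels[i] for i in test))
```

With k=5, each test fold holds 30 samples, 10 from each class. A plain, unshuffled k-fold split on the same sorted labels would instead produce test folds containing a single class.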

2. Run the Flow

python run
     Workflow starting (run-id 1666720725993465):
[1666720725993465/start/1 (pid 52705)] Task is starting.
[1666720725993465/start/1 (pid 52705)] Task finished successfully.
[1666720725993465/train_xgb/2 (pid 52708)] Task is starting.
[1666720725993465/train_xgb/2 (pid 52708)] Task finished successfully.
[1666720725993465/end/3 (pid 52714)] Task is starting.
[1666720725993465/end/3 (pid 52714)] Gradient Boosted Trees Model Accuracy: 96.667 ± 2.108%
[1666720725993465/end/3 (pid 52714)] Task finished successfully.
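The end step reports 100 times the mean and standard deviation of the per-fold accuracies, each rounded to three decimals. np.std defaults to the population standard deviation (ddof=0). The sketch below reproduces that arithmetic with Python's statistics module, using hypothetical fold scores (not the scores from the run above), and contrasts the result with the sample standard deviation.

```python
import statistics

# Hypothetical per-fold accuracies for 5-fold CV on 150 samples
# (30 test samples per fold); NOT the scores from the run above.
scores = [30 / 30, 28 / 30, 29 / 30, 27 / 30, 30 / 30]

# Same arithmetic as the end step: percentages rounded to 3 decimals.
# statistics.pstdev matches np.std's default (population std, ddof=0).
mean_pct = round(100 * statistics.mean(scores), 3)
std_pct = round(100 * statistics.pstdev(scores), 3)

# Sample standard deviation (ddof=1) for comparison; it is larger.
sample_std_pct = round(100 * statistics.stdev(scores), 3)

msg = "Gradient Boosted Trees Model Accuracy: {} \u00B1 {}%"
print(msg.format(mean_pct, std_pct))
print("sample std:", sample_std_pct)
```

With only five folds, the choice between population and sample standard deviation noticeably changes the reported spread, so it is worth knowing which one a report uses.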

Note that XGBoost has two ways to train a booster model. This example uses XGBoost's scikit-learn API. If you use XGBoost's learning API (xgboost.train) instead, you have to wrap your data in xgboost.DMatrix objects. These objects cannot be serialized with pickle, so they cannot be stored directly as artifacts on self. See this example to learn how to handle cases where objects you want to store on self cannot be pickled.
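As a generic Python illustration of the problem and one possible workaround (an assumption, not necessarily the approach the example above takes), you can pickle only the raw data and rebuild the native object on load with __getstate__ and __setstate__. NativeHandle is a hypothetical stand-in for an object such as xgboost.DMatrix. It holds a threading.Lock only because locks cannot be pickled. In a real flow, a simpler option is often to keep the NumPy arrays on self and construct the DMatrix inside the step that trains the booster.

```python
import pickle
import threading


class NativeHandle:
    """Hypothetical stand-in for an object, such as xgboost.DMatrix, that
    wraps native state and therefore cannot be pickled."""

    def __init__(self, rows):
        self.rows = rows
        self._lock = threading.Lock()  # thread locks cannot be pickled


class PicklableData:
    """Pickle only the raw rows and rebuild the native handle on load."""

    def __init__(self, rows):
        self.rows = rows
        self.handle = NativeHandle(rows)

    def __getstate__(self):
        # Drop the unpicklable handle; keep only plain Python data.
        return {"rows": self.rows}

    def __setstate__(self, state):
        self.rows = state["rows"]
        self.handle = NativeHandle(self.rows)  # rebuild after unpickling


# Pickling the raw handle fails ...
try:
    pickle.dumps(NativeHandle([[1.0, 2.0]]))
    handle_pickled = True
except TypeError:
    handle_pickled = False

# ... but the wrapper round-trips and recreates its handle.
restored = pickle.loads(pickle.dumps(PicklableData([[1.0, 2.0], [3.0, 4.0]])))
print(handle_pickled, restored.rows)
```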

In the last two episodes, you wrote flows to train random forest and XGBoost models. In the next episode, you will start to see the power of Metaflow as you merge these two flows into one and train both models in parallel. Metaflow lets you branch a flow into parallel tasks, and the next episode provides a template for doing this.