Pass XGBoost DMatrix Between Metaflow Steps

Question

XGBoost uses a data structure called a DMatrix, which I cannot assign to self because it is not pickleable. How do I pass a DMatrix between steps?

Solution

The easiest solution is to call the DMatrix's save_binary method and store the file name you saved to as a Metaflow artifact.
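
As a minimal sketch of this approach, the helpers below save a DMatrix and return an absolute path string that can be assigned to self. The names save_dmatrix_to_path and require_local_file are hypothetical and not part of Metaflow or XGBoost; the only library call used is save_binary. The sketch assumes every step runs on a machine that can see the same file system, because only the path string is stored as an artifact, not the file itself.

```python
from pathlib import Path


def save_dmatrix_to_path(dmatrix, file_name):
    """Save a DMatrix to file_name and return its absolute path as a string.

    Hypothetical helper: `dmatrix` only needs a save_binary(path) method,
    which xgboost.DMatrix provides.
    """
    path = Path(file_name).resolve()
    dmatrix.save_binary(str(path))
    return str(path)


def require_local_file(path):
    """Return path unchanged if the file exists, else raise a clear error."""
    if not Path(path).is_file():
        raise FileNotFoundError(
            f"{path} is not visible from this step; the path artifact only "
            "works when steps share a file system."
        )
    return path
```

In a flow, this might look like self.dmatrix_path = save_dmatrix_to_path(dmatrix, 'xgb_data.xgb') in one step and xgb.DMatrix(require_local_file(self.dmatrix_path)) in a later one.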

However, if you want to serialize the contents of a DMatrix so you can access the same object across steps, you have to perform the following workaround:

1. Define Helper Functions

Saving The Data

  1. First serialize the data in DMatrix to disk by using the save_binary method.
  2. Read the file's bytes into a variable and assign it to self. The save_matrix helper in the code below does this.

Loading the Data

  1. Save the binary data you stored in self from the previous step to disk.
  2. Load the data into xgb.DMatrix using the file name. The write_binary helper and the end step in the code below do this.
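
The saving and loading steps above can also be sketched as a generic round trip through a temporary directory, which avoids leaving a fixed file such as xgb_data.xgb behind. The names to_bytes and from_bytes are hypothetical. The sketch assumes the loader reads the whole file before returning, since the temporary file is deleted as soon as the function exits.

```python
import os
import tempfile


def to_bytes(save_fn, suffix=".bin"):
    """Call save_fn(path) on a temporary file and return the file's bytes."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "payload" + suffix)
        save_fn(path)  # e.g. dmatrix.save_binary
        with open(path, "rb") as f:
            return f.read()


def from_bytes(data, load_fn, suffix=".bin"):
    """Write data to a temporary file and return load_fn(path).

    Assumption: load_fn finishes reading the file before it returns,
    because the temporary directory is removed on exit.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "payload" + suffix)
        with open(path, "wb") as f:
            f.write(data)
        return load_fn(path)
```

With XGBoost, the saving side would be self.xgb_data = to_bytes(dmatrix.save_binary) and the loading side would be data = from_bytes(self.xgb_data, xgb.DMatrix), under the assumption stated above that xgb.DMatrix loads the binary file into memory when it is constructed.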

2. Run Flow

This flow shows how to use save_matrix and write_binary so you can serialize an xgboost.DMatrix.

We can run the flow and see that the DMatrix contents are propagated as expected:

pass_dmatrix_between_steps.py
from metaflow import FlowSpec, step, Parameter
import xgboost as xgb
import numpy as np
from tempfile import NamedTemporaryFile

def save_matrix(dmatrix, file_name):
    # Serialize the DMatrix to disk, then read the file back as bytes.
    dmatrix.save_binary(file_name)
    with open(file_name, 'rb') as f:
        xgb_data = f.read()
    return xgb_data

def write_binary(xgb_data, file_name):
    # Write the bytes back to disk so xgb.DMatrix can load them by file name.
    with open(file_name, 'wb') as f:
        f.write(xgb_data)

class SerializeXGBDataFlow(FlowSpec):

    file_name = Parameter('file_name',
                          default='xgb_data.xgb')

    @step
    def start(self):
        dmatrix = xgb.DMatrix(np.random.rand(5, 10))
        # The bytes are pickleable, so they can be stored as an artifact.
        self.xgb_data = save_matrix(dmatrix,
                                    self.file_name)
        self.next(self.end)

    @step
    def end(self):
        # Rebuild the DMatrix from the bytes stored in the previous step.
        write_binary(self.xgb_data, self.file_name)
        data = xgb.DMatrix(self.file_name)
        print(f'there are {data.num_row()} ' + \
              'rows in the data.')

if __name__ == '__main__':
    SerializeXGBDataFlow()

python pass_dmatrix_between_steps.py run
...
[1654221299103134/end/2 (pid 71559)] Task is starting.
[1654221299103134/end/2 (pid 71559)] there are 5 rows in the data.
[1654221299103134/end/2 (pid 71559)] Task finished successfully.
...

Further Reading

For help or feedback, please join Metaflow Slack. To suggest an article, you may open an issue on GitHub.