
Track Artifacts with Weights and Biases

Question

How can I track artifacts of my flows with Weights and Biases?

Solution

You can track flow artifacts by making the same Weights and Biases calls you already use, directly inside your Metaflow steps. This is especially useful when you want to track artifacts throughout the lifecycle of long-running tasks, not only after they finish.
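
The long-running case can look like the minimal sketch below. The loop reports a metric every few iterations while it runs. The function name, the toy `loss` value, and the `log` callback are illustrative assumptions, not part of the original guide. Inside a Metaflow step you would call `wandb.init()` first and pass a callback that forwards to `wandb.log`.

```python
import math
from typing import Callable, Dict


def train_with_progress_logging(
    n_iterations: int,
    log: Callable[[Dict[str, float], int], None],
    log_every: int = 10,
) -> float:
    """Run a toy loop and report a metric every `log_every` iterations.

    `log(metrics, step)` stands in for a tracking call. With Weights and
    Biases you could pass `lambda m, s: wandb.log(m, step=s)` after
    calling `wandb.init()`.
    """
    if n_iterations < 1 or log_every < 1:
        raise ValueError("n_iterations and log_every must be positive")
    loss = float("nan")
    for i in range(1, n_iterations + 1):
        # Placeholder for real work: a loss that decays over time.
        loss = math.exp(-i / n_iterations)
        # Report periodically, and always on the final iteration.
        if i % log_every == 0 or i == n_iterations:
            log({"loss": loss}, i)
    return loss
```

Passing the logger in as a callable keeps the loop independent of any tracking service, so you can test it without an account.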

1. Log in to Weights and Biases

To run this code, first sign up for a Weights and Biases account and log in with your API key. We recommend storing the key as an environment variable. The example later in this guide also reads the Weights and Biases "entity" and "project" from environment variables:

export WANDB_API_KEY=<YOUR KEY>
export WANDB_ENTITY=<YOUR USERNAME>
export WANDB_PROJECT=<YOUR PROJECT>
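
Before launching the flow, you may want it to fail fast if one of these variables is missing. The sketch below is a hypothetical helper (not part of Metaflow or wandb) that returns the names of any of the three variables that are unset or empty.

```python
import os
from typing import List, Mapping, Optional, Sequence

# Variables this guide expects to be exported before running the flow.
REQUIRED_WANDB_VARS = ("WANDB_API_KEY", "WANDB_ENTITY", "WANDB_PROJECT")


def missing_wandb_vars(
    env: Optional[Mapping[str, str]] = None,
    required: Sequence[str] = REQUIRED_WANDB_VARS,
) -> List[str]:
    """Return the required variable names that are unset or empty.

    `env` defaults to os.environ; passing a dict makes the check easy to test.
    """
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name)]
```

For example, calling `missing_wandb_vars()` at the top of a launch script and exiting when the list is non-empty surfaces the problem before any remote task starts.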

Next, install the Weights and Biases Python client:

pip install wandb

If you don't set the WANDB_API_KEY environment variable, you will need to paste your key after running:

wandb login

2. Define a Logging Function

Here is a function that takes a train/test split of a dataset and a trained classification model, and logs plots about them to Weights and Biases. It uses the Weights and Biases scikit-learn integration, but you can replace it with any logging calls relevant to your workflow.

wandb_helpers.py
import os
import wandb


def plot_results(X_train, y_train, X_test, y_test,
                 y_pred, y_probs, clf, labels):
    # Start a run; entity and project come from the environment variables set earlier.
    wandb.init(entity=os.getenv("WANDB_ENTITY"), project=os.getenv("WANDB_PROJECT"))
    # Individual scikit-learn plots.
    wandb.sklearn.plot_class_proportions(y_train, y_test, labels)
    wandb.sklearn.plot_learning_curve(clf, X_train, y_train)
    wandb.sklearn.plot_roc(y_test, y_probs, labels)
    wandb.sklearn.plot_precision_recall(y_test, y_probs, labels)
    wandb.sklearn.plot_feature_importances(clf)
    # All-in-one classifier summary. Iris has three classes, so the model is not binary.
    wandb.sklearn.plot_classifier(
        clf, X_train, X_test, y_train, y_test, y_pred,
        y_probs, labels, is_binary=False,
        model_name='RandomForest'
    )
    # Close the run so everything is uploaded.
    wandb.finish()
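
The scikit-learn plots are only one option. The sketch below is a hypothetical replacement, not part of the original guide. It computes overall accuracy and per-class recall with the standard library and hands them to any logging callable. Taking the callable as a parameter, rather than calling `wandb.log` directly, is a design choice that lets you test the function without a Weights and Biases account.

```python
from collections import Counter
from typing import Callable, Dict, Sequence


def log_summary_metrics(
    y_test: Sequence[int],
    y_pred: Sequence[int],
    labels: Sequence[str],
    log: Callable[[Dict[str, float]], None],
) -> Dict[str, float]:
    """Compute overall accuracy and per-class recall, then pass them to `log`.

    Class indices in y_test/y_pred are assumed to match positions in `labels`.
    """
    if len(y_test) == 0 or len(y_test) != len(y_pred):
        raise ValueError("y_test and y_pred must be non-empty and equal length")
    correct: Counter = Counter()
    total: Counter = Counter()
    for truth, pred in zip(y_test, y_pred):
        total[truth] += 1
        if truth == pred:
            correct[truth] += 1
    metrics = {"accuracy": sum(correct.values()) / len(y_test)}
    # Recall per class: fraction of that class's test rows predicted correctly.
    for index, name in enumerate(labels):
        if total[index]:
            metrics[f"recall/{name}"] = correct[index] / total[index]
    log(metrics)
    return metrics
```

Inside `plot_results`, a call such as `log_summary_metrics(y_test, y_pred, labels, wandb.log)` placed between `wandb.init(...)` and `wandb.finish()` would record these values alongside the plots.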

3. Run the Flow

The flow shows how to:

  • Load data in the start step.
  • Build a model and call a custom logging function in the model step.
    • The model step calls the custom logging function plot_results to log the model's results to Weights and Biases.
    • This step uses Metaflow's @environment decorator to pass the Weights and Biases environment variables into the step. This matters when the step runs on a remote machine through a Metaflow decorator like @batch or @kubernetes, because the remote task does not inherit your local shell's environment.
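
The `vars` dictionary passed to `@environment` is evaluated when the flow file is loaded. As a result, `os.getenv` returns `None` for any variable that is unset in your shell. The sketch below is a hypothetical helper (an assumption, not part of Metaflow) that builds the dictionary and skips anything that is not set.

```python
import os
from typing import Dict, Iterable, Mapping, Optional


def wandb_environment_vars(
    run_name: str,
    names: Iterable[str] = ("WANDB_API_KEY", "WANDB_PROJECT", "WANDB_ENTITY"),
    env: Optional[Mapping[str, str]] = None,
) -> Dict[str, str]:
    """Build a dict suitable for @environment(vars=...).

    WANDB_NAME is always set to `run_name`; the other variables are copied
    from `env` (default: os.environ) only when they are non-empty.
    """
    env = os.environ if env is None else env
    result = {"WANDB_NAME": run_name}
    for name in names:
        value = env.get(name)
        if value:
            result[name] = value
    return result
```

With such a helper, the decorator in the flow below could read `@environment(vars=wandb_environment_vars("Plot RandomForestClassifier"))`.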
track_with_wandb_custom.py
from metaflow import FlowSpec, step, environment, batch, conda_base
import os
import wandb
from wandb_helpers import plot_results


@conda_base(libraries={"wandb": "0.12.15", "scikit-learn": "1.0.2", "pandas": "1.4.2"})
class TrackPlotsFlow(FlowSpec):

    @step
    def start(self):
        from sklearn import datasets
        from sklearn.model_selection import train_test_split
        # Load the iris dataset: features, targets, and class names.
        self.iris = datasets.load_iris()
        self.X = self.iris['data']
        self.y = self.iris['target']
        self.labels = self.iris['target_names']
        # Hold out 20% of the rows for evaluation.
        (self.X_train, self.X_test,
         self.y_train, self.y_test) = train_test_split(self.X, self.y,
                                                       test_size=0.2)
        self.next(self.model)

    # Copy env vars to tasks on a different machine.
    @environment(vars={
        "WANDB_NAME": "Plot RandomForestClassifier",
        "WANDB_API_KEY": os.getenv("WANDB_API_KEY"),
        "WANDB_PROJECT": os.getenv("WANDB_PROJECT"),
        "WANDB_ENTITY": os.getenv("WANDB_ENTITY"),
    })
    @batch(cpu=2)
    @step
    def model(self):
        from sklearn.ensemble import RandomForestClassifier
        self.clf = RandomForestClassifier(
            n_estimators=10, max_depth=None,
            min_samples_split=2, random_state=0
        )
        self.clf.fit(self.X_train, self.y_train)
        self.y_pred = self.clf.predict(self.X_test)
        self.y_probs = self.clf.predict_proba(self.X_test)
        # Log plots for this model to Weights and Biases.
        plot_results(self.X_train, self.y_train,
                     self.X_test, self.y_test,
                     self.y_pred, self.y_probs,
                     self.clf, self.labels)
        self.next(self.end)

    @step
    def end(self):
        print("Flow is all done.")


if __name__ == "__main__":
    TrackPlotsFlow()

Run the flow with the conda environment enabled:

python track_with_wandb_custom.py --environment=conda run

Workflow starting (run-id 216534):
[216534/start/1331131 (pid 6940)] Task is starting.
[216534/start/1331131 (pid 6940)] Task finished successfully.
[216534/model/1331132 (pid 7279)] Task is starting.
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Task is starting (status SUBMITTED)...
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Task is starting (status RUNNABLE)...
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Task is starting (status RUNNABLE)...
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Task is starting (status RUNNABLE)...
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Task is starting (status RUNNABLE)...
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Task is starting (status RUNNABLE)...
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Task is starting (status STARTING)...
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Task is starting (status RUNNING)...
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Setting up task environment.
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Downloading code package...
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Code package downloaded.
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Task is starting.
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Bootstrapping virtual environment...
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Environment bootstrapped.
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] eddiem ob-how-to
[216534/model/1331132 (pid 7279)] [49f0df00-e3e4-4ef8-ae78-300145965ec9] Task finished with exit code 0.
[216534/model/1331132 (pid 7279)] Task finished successfully.
[216534/end/1331133 (pid 12751)] Task is starting.
[216534/end/1331133 (pid 12751)] Flow is all done.
[216534/end/1331133 (pid 12751)] Task finished successfully.
Done!

Further Reading