
Track Artifacts with Weights and Biases

Question

How can I track artifacts of my flows with Weights and Biases?

Solution

You can track flow artifacts using any Weights and Biases calls you already use. This can be especially useful if you want to track artifacts during the lifecycle of long-running tasks.
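The long-running case mentioned above comes down to logging from inside the task as it progresses, instead of once at the end. The following is a minimal sketch of that pattern. `train_with_progress` is a hypothetical function, and the halving loss is only a placeholder for real training. The logging call is injected as `log_fn` so the sketch runs without a Weights and Biases account. In a real step you would pass `run.log`, where `run` is the object returned by `wandb.init`.

```python
from typing import Callable, Dict, List


def train_with_progress(
    n_epochs: int,
    log_fn: Callable[[Dict[str, float]], None],
) -> List[float]:
    """Run a placeholder training loop, reporting the loss after each epoch.

    `log_fn` receives one dict per epoch. In a Metaflow step this could be
    `run.log` from a Weights and Biases run (an assumption of this sketch).
    """
    losses: List[float] = []
    loss = 1.0
    for epoch in range(n_epochs):
        loss *= 0.5  # placeholder for one epoch of real training
        losses.append(loss)
        # Report progress immediately rather than waiting for the task to end.
        log_fn({"epoch": epoch, "loss": loss})
    return losses
```

Because each epoch is logged as soon as it finishes, the Weights and Biases dashboard shows partial results even if the task later fails or is still running.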

1. Log in to Weights and Biases

To run this code, first sign up for a Weights and Biases account and make sure you have logged in with your API key. It is recommended to store your key as an environment variable. In the example shown later, the Weights and Biases "entity" and "project" are also stored as environment variables:

export WANDB_API_KEY=<YOUR KEY>
export WANDB_ENTITY=<YOUR USERNAME>
export WANDB_PROJECT=<YOUR PROJECT>

Then install the Weights and Biases Python client:

pip install wandb

If you don't set the WANDB_API_KEY environment variable, you will need to paste your key after running:

wandb login

2. Define a Logging Function

Here is a function that takes the train and test splits, the model's predictions and class probabilities, the fitted classifier, and the class labels, and logs plots of them with Weights and Biases. It uses the Weights and Biases scikit-learn integration, but you can replace it with any logging functions relevant to your workflow.

wandb_helpers.py
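import os
import wandb


def plot_results(X_train, y_train, X_test, y_test,
                 y_pred, y_probs, clf, labels):
    # mode="offline" writes the run to a local ./wandb directory instead of
    # uploading it. Upload it later with `wandb sync`, or drop this argument
    # to stream results to W&B. On a remote machine such as AWS Batch, the
    # local directory is not kept after the task ends.
    wandb.init(entity=os.getenv("WANDB_ENTITY"),
               project=os.getenv("WANDB_PROJECT"),
               mode="offline")
    wandb.sklearn.plot_class_proportions(y_train,
                                         y_test,
                                         labels)
    wandb.sklearn.plot_learning_curve(clf,
                                      X_train,
                                      y_train)
    wandb.sklearn.plot_roc(y_test, y_probs, labels)
    wandb.sklearn.plot_precision_recall(y_test,
                                        y_probs,
                                        labels)
    wandb.sklearn.plot_feature_importances(clf)
    # The iris dataset has three classes, so this is not a binary problem.
    wandb.sklearn.plot_classifier(
        clf, X_train, X_test, y_train, y_test, y_pred,
        y_probs, labels, is_binary=False,
        model_name='RandomForest'
    )
    wandb.finish()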
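If plots are more than you need, the same approach works with scalar metrics. This is a minimal sketch, assuming integer class ids that index into `labels` (true for the iris `target` and `target_names` used in the flow). `per_class_recall` and `log_recall` are hypothetical names, not part of the wandb API. The run is opened with `with wandb.init(...) as run:`, so it is finished when the block exits.

```python
import os
from typing import Dict, Sequence


def per_class_recall(y_true: Sequence[int], y_pred: Sequence[int],
                     labels: Sequence[str]) -> Dict[str, float]:
    """Fraction of each class's test samples that were predicted correctly.

    Assumes integer class ids that index into `labels`.
    """
    metrics: Dict[str, float] = {}
    for class_id, name in enumerate(labels):
        total = sum(1 for t in y_true if t == class_id)
        hits = sum(1 for t, p in zip(y_true, y_pred)
                   if t == class_id and p == class_id)
        # Classes absent from the test split get NaN instead of dividing by zero.
        metrics[f"recall/{name}"] = hits / total if total else float("nan")
    return metrics


def log_recall(y_true, y_pred, labels, mode: str = "online") -> None:
    """Log per-class recall as scalars to a Weights and Biases run."""
    import wandb  # deferred so per_class_recall works without wandb installed

    # Using the run as a context manager finishes it when the block exits.
    with wandb.init(entity=os.getenv("WANDB_ENTITY"),
                    project=os.getenv("WANDB_PROJECT"),
                    mode=mode) as run:
        run.log(per_class_recall(y_true, y_pred, labels))
```

Inside the `model` step, `log_recall(self.y_test, self.y_pred, self.labels)` could be called in place of, or alongside, `plot_results`.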

3. Run the Flow

The flow shows how to:

  • Load data in the start step.
  • Build a model and call a custom logging function in the model step.
    • Call the custom logging function plot_results to log plots of the model's results to Weights and Biases.
    • This step uses Metaflow's @environment decorator to pass the environment variables that Weights and Biases needs into the step. This is useful when the step runs on a remote machine through a Metaflow decorator like @batch or @kubernetes.
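`os.getenv` returns `None` for an unset variable, so a typo or a missing `export` can silently produce an incomplete `vars` dict. A small helper can build the dict and warn on the launching machine instead. This is a minimal sketch: `wandb_environment` and `WANDB_VARS` are hypothetical names, and falling back to an empty string for unset variables is an assumption of the sketch, not documented Metaflow behavior.

```python
import os
import warnings
from typing import Dict, Mapping, Optional

# Hypothetical list of the variables forwarded to remote tasks.
WANDB_VARS = ("WANDB_API_KEY", "WANDB_ENTITY", "WANDB_PROJECT")


def wandb_environment(run_name: str,
                      env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Build a dict suitable for @environment(vars=...).

    Reads from os.environ by default. Unset or empty variables trigger a
    warning and are passed on as empty strings.
    """
    source = os.environ if env is None else env
    missing = [name for name in WANDB_VARS if not source.get(name)]
    if missing:
        warnings.warn("Unset Weights and Biases variables: " + ", ".join(missing))
    values = {name: source.get(name) or "" for name in WANDB_VARS}
    values["WANDB_NAME"] = run_name
    return values
```

With the helper saved next to the flow file (as wandb_helpers.py is), the decorator could read `@environment(vars=wandb_environment("Plot RandomForestClassifier"))`.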
track_with_wandb_custom.py
from metaflow import FlowSpec, step, environment, batch, conda_base
import os
import wandb
from wandb_helpers import plot_results


@conda_base(libraries={"wandb": "0.12.15", "scikit-learn": "1.0.2",
                       "pandas": "1.4.2"})
class TrackPlotsFlow(FlowSpec):

    @step
    def start(self):
        from sklearn import datasets
        from sklearn.model_selection import train_test_split
        self.iris = datasets.load_iris()
        self.X = self.iris['data']
        self.y = self.iris['target']
        self.labels = self.iris['target_names']
        split = train_test_split(self.X, self.y,
                                 test_size=0.2)
        self.X_train = split[0]
        self.X_test = split[1]
        self.y_train = split[2]
        self.y_test = split[3]
        self.next(self.model)

    # Copy env vars to tasks on a different machine.
    @environment(vars={
        "WANDB_API_KEY": os.getenv("WANDB_API_KEY"),
        "WANDB_NAME": "Plot RandomForestClassifier",
        "WANDB_ENTITY": os.getenv("WANDB_ENTITY"),
        "WANDB_PROJECT": os.getenv("WANDB_PROJECT")
    })
    @batch(cpu=2)
    @step
    def model(self):
        from sklearn.ensemble import RandomForestClassifier
        self.clf = RandomForestClassifier(
            n_estimators=10, max_depth=None,
            min_samples_split=2, random_state=0
        )
        from sklearn.model_selection import cross_val_score
        self.clf.fit(self.X_train, self.y_train)
        self.y_pred = self.clf.predict(self.X_test)
        self.y_probs = self.clf.predict_proba(
            self.X_test
        )
        plot_results(self.X_train, self.y_train,
                     self.X_test, self.y_test,
                     self.y_pred, self.y_probs,
                     self.clf, self.labels)
        self.next(self.end)

    @step
    def end(self):
        print("Flow is all done.")


if __name__ == "__main__":
    TrackPlotsFlow()
python track_with_wandb_custom.py --environment=conda run
Workflow starting (run-id 110):
[110/start/494 (pid 11536)] Task is starting.
[110/start/494 (pid 11536)] Task finished successfully.
[110/model/495 (pid 11746)] Task is starting.
[110/model/495 (pid 11746)] [6d6302c8-20ee-4212-9d09-e67fc49beead] Task is starting (status SUBMITTED)...
[110/model/495 (pid 11746)] [6d6302c8-20ee-4212-9d09-e67fc49beead] Task is starting (status RUNNABLE)...
[110/model/495 (pid 11746)] [6d6302c8-20ee-4212-9d09-e67fc49beead] Task is starting (status RUNNING)...
[110/model/495 (pid 11746)] [6d6302c8-20ee-4212-9d09-e67fc49beead] Setting up task environment.
[110/model/495 (pid 11746)] [6d6302c8-20ee-4212-9d09-e67fc49beead] Downloading code package...
[110/model/495 (pid 11746)] [6d6302c8-20ee-4212-9d09-e67fc49beead] Code package downloaded.
[110/model/495 (pid 11746)] [6d6302c8-20ee-4212-9d09-e67fc49beead] Task is starting.
[110/model/495 (pid 11746)] [6d6302c8-20ee-4212-9d09-e67fc49beead] Bootstrapping environment...
[110/model/495 (pid 11746)] [6d6302c8-20ee-4212-9d09-e67fc49beead] Environment bootstrapped.
[110/model/495 (pid 11746)] [6d6302c8-20ee-4212-9d09-e67fc49beead] Task finished with exit code 0.
[110/model/495 (pid 11746)] Task finished successfully.
[110/end/496 (pid 12169)] Task is starting.
[110/end/496 (pid 12169)] Flow is all done.
[110/end/496 (pid 12169)] Task finished successfully.
Done!
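Metaflow also persists every `self.` assignment as a data artifact, so the same results can be read back after the run without going through Weights and Biases. This is a minimal sketch; `stored_accuracy` is a hypothetical helper that relies only on the `run.data` attribute of the Metaflow Client API, which exposes the artifacts of the run's `end` step. `y_test` and `y_pred` are assigned in earlier steps and carried forward to `end`.

```python
def stored_accuracy(run) -> float:
    """Fraction of correct predictions stored in a finished TrackPlotsFlow run.

    Expects `run.data` to expose the `y_test` and `y_pred` artifacts.
    """
    y_test, y_pred = run.data.y_test, run.data.y_pred
    if len(y_test) == 0:
        raise ValueError("The run has an empty test split.")
    correct = sum(1 for t, p in zip(y_test, y_pred) if t == p)
    return correct / len(y_test)
```

For example, after `from metaflow import Flow`, calling `stored_accuracy(Flow("TrackPlotsFlow").latest_successful_run)` reports the accuracy from the most recent successful run in your current namespace.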


For help or feedback, please join Metaflow Slack. To suggest an article, you may open an issue on GitHub.
