
Indestructible training and fine-tuning with @checkpoint

Metaflow now provides a convenient and flexible way to persist the progress of long-running tasks - especially training and fine-tuning - using the new @checkpoint decorator. Checkpoints are particularly useful when working with spot instances, which recently received enhanced support in Metaflow as well. Read on to get an overview of the new features, and check out the documentation when you're ready to get started!

If you have run long-running tasks in the cloud (or an on-prem data center), you've likely encountered your fair share of unexpected failures - instances disappearing, hardware failing, or fellow humans fat-fingering configuration causing workloads to fail on the fly. Restarting tasks and losing hours (or more) of progress is certainly inconvenient, but with GPU-intensive AI workloads, it can also become a significant cost concern.

One of the core features of Metaflow is comprehensive snapshotting of flow progress through artifacts, enabling you to resume computation at any step. However, since artifacts are only persisted upon task completion, they do not support saving intermediate progress during task execution, such as checkpointing partially trained models in long-running training steps.

Over the past few months, we have been battle-hardening a new @checkpoint decorator to plug this gap. As the name implies, it allows you to checkpoint progress within a task, integrating with popular ML/AI frameworks such as PyTorch and XGBoost, as well as fine-tuning libraries for LLM/GenAI models such as LLaMAFactory and HuggingFace. Today, we are happy to release comprehensive documentation for @checkpoint, so you can start using it confidently in your own projects.

Introducing indestructible training and fine-tuning steps

It takes only a few lines of code to harden training steps using the @checkpoint decorator - take a look at the new documentation and an example repository for reference. When combined with @retry, training steps are able to recover from failures and continue training with a minimal delay.
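To make the recovery idea concrete, here is a minimal, framework-independent sketch of the checkpoint-and-retry pattern. It deliberately uses only the Python standard library and does not call any Metaflow API: the train and run_with_retries functions, the state.json file name, and the simulated failures are all hypothetical stand-ins for what @checkpoint and @retry take care of in a real flow.

```python
import json
import os
import tempfile


def train(ckpt_dir, total_steps, fail_at=None):
    """Toy training loop that resumes from a checkpoint in ckpt_dir, if any."""
    path = os.path.join(ckpt_dir, "state.json")  # hypothetical checkpoint file
    # Resume from persisted state if a previous attempt left one behind.
    if os.path.exists(path):
        with open(path) as f:
            state = json.load(f)
    else:
        state = {"step": 0, "loss": 1.0}

    while state["step"] < total_steps:
        state["step"] += 1
        state["loss"] *= 0.9  # stand-in for a real optimization step
        # Write to a temp file and rename, so a crash mid-write
        # cannot leave a corrupt checkpoint behind.
        tmp = path + ".tmp"
        with open(tmp, "w") as f:
            json.dump(state, f)
        os.replace(tmp, path)
        if fail_at is not None and state["step"] == fail_at:
            raise RuntimeError("simulated instance failure")
    return state


def run_with_retries(ckpt_dir, total_steps, failures):
    """Re-run the task until it succeeds; each attempt resumes, not restarts."""
    attempts = 0
    pending = list(failures)
    while True:
        attempts += 1
        fail_at = pending.pop(0) if pending else None
        try:
            return train(ckpt_dir, total_steps, fail_at), attempts
        except RuntimeError:
            continue  # retry; the next attempt picks up the saved state


with tempfile.TemporaryDirectory() as d:
    # Two simulated failures, at steps 3 and 7, then a clean attempt.
    final_state, attempts = run_with_retries(d, total_steps=10, failures=[3, 7])
```

In this toy run the task is attempted three times, yet no step is ever repeated: each retry continues from the last saved step. The decorators give you the same behavior without hand-written bookkeeping, and they also persist checkpoints outside the instance, which this local sketch does not attempt.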

This short clip showcases @checkpoint in action: We begin training an XGBoost model on a cloud instance, tracking progress through a card. Then, we unleash our inner chaos monkey, repeatedly terminating the task - yet training seamlessly recovers and continues as if nothing happened:

What’s the (@check)point?

The concept of checkpointing has been around for decades - no seriously large model is trained without it. Is there anything new or noteworthy about the @checkpoint decorator specifically?

Think of checkpoints as a form of persisted state. Managing state - storing, loading, and tracking it - is a non-trivial challenge. While most ML libraries make it easy to save model state to disk, they leave you to figure out how to persist and transfer checkpoints across instances, determine which checkpoints to load, and handle the loading process efficiently.

Imagine a team of developers rapidly iterating on models, generating a continuous stream of checkpoints from multiple concurrent experiments. Things can quickly become chaotic. This is where @checkpoint shines - it seamlessly integrates with Metaflow’s built-in mechanisms to keep experiments and production runs organized. Checkpoints:

  • Integrate efficiently with Metaflow infrastructure, the datastore in particular, so there’s no need to worry about infrastructure for persisting and transferring checkpoints across instances.
  • Work consistently in all environments, from local development to @batch, @kubernetes, and even on Slurm clusters, thanks to @slurm. No need to deal with expensive and finicky distributed filesystems.
  • Are versioned and tracked alongside Metaflow runs and artifacts, benefiting from Metaflow namespaces, branched deployments, and easy access with the Client API.
  • Are readily observable in the UI through a built-in @card that visualizes metadata about checkpoints in real time.
  • Enable collaborative development - you can seamlessly share them between colleagues or between a production cluster handling heavy workloads and development environments, without worrying about using the wrong checkpoints or interfering with others.
  • Come with easy APIs for accessing checkpoints inside and outside of flows.
  • Support demanding distributed training use cases with the flexibility to customize behavior, adapting to even complex training setups.

Take a look at the @checkpoint documentation for details.

Smart checkpointing on spot instances

Recently, we introduced another related feature in Metaflow: the ability to proactively respond to spot instance terminations. Many Metaflow users leverage spot instances for cost-efficient workload execution, with @retry handling automatic retries when instances are terminated.

Fortunately, many clouds deliver a warning to a spot instance shortly before terminating it, giving the task time to perform a quick cleanup - or a checkpoint - so that it can recover quickly. By combining spot termination notices, @checkpoint, and @retry, we can implement smart checkpointing policies - minimizing overhead by checkpointing infrequently while ensuring a final checkpoint is created precisely when needed.

Here’s an example of a spot-aware checkpointer for XGBoost. It creates checkpoints at regular intervals, unless a termination notice is received, in which case it checkpoints immediately and terminates the task, to avoid wasting time and money:

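import os
import pickle
from metaflow import current
import xgboost

class SpotCheckpointer(xgboost.callback.TrainingCallback):

    @classmethod
    def _path(cls):
        # Location of the pickled model inside the checkpoint directory
        return os.path.join(current.checkpoint.directory, 'xgb_cp.pkl')

    def __init__(self, interval=100):
        # Number of boosting rounds between regular checkpoints
        self._interval = interval

    def after_iteration(self, model, epoch, evals_log):
        # The notice is checked as a file path: it exists once a
        # termination notice has been received
        is_terminating = os.path.exists(current.spot_termination_notice)
        if (epoch > 0 and epoch % self._interval == 0) or is_terminating:
            with open(self._path(), 'wb') as f:
                pickle.dump(model, f)
            # Persist the contents of the checkpoint directory
            current.checkpoint.save()
            if is_terminating:
                # Fail fast so that @retry can resume on a new instance
                raise Exception("Spot instance terminating")

    @classmethod
    def load(cls):
        # Restore the model saved by the latest checkpoint
        with open(cls._path(), 'rb') as f:
            return pickle.load(f)
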

To test this in action, plug SpotCheckpointer into the CheckpointXGBoost example found in the documentation. You can adapt the same pattern easily for other frameworks including PyTorch or fine-tuning libraries, following examples in our reference repository.
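When porting the pattern to another framework, it can help to separate the checkpointing policy from the framework-specific saving code. The sketch below isolates the same decision rule that SpotCheckpointer uses and exercises it in a simulated loop. It is an illustration only: should_checkpoint is a hypothetical helper, and the termination notice is simulated with a temporary sentinel file rather than a real notice from Metaflow.

```python
import os
import tempfile


def should_checkpoint(epoch, interval, notice_path):
    """Return (save, stop) for the current epoch.

    save is True at every `interval` epochs or when a termination notice
    file exists; stop is True only when the notice file exists.
    """
    terminating = os.path.exists(notice_path)
    periodic = epoch > 0 and epoch % interval == 0
    return (periodic or terminating), terminating


with tempfile.TemporaryDirectory() as d:
    notice = os.path.join(d, "termination_notice")  # hypothetical sentinel file
    saved_at = []
    for epoch in range(1000):
        if epoch == 250:
            # Simulate the cloud's warning arriving mid-training.
            open(notice, "w").close()
        save, stop = should_checkpoint(epoch, 100, notice)
        if save:
            saved_at.append(epoch)  # a real callback would save the model here
        if stop:
            break  # a real callback would raise so the task can be retried

print(saved_at)
```

With an interval of 100 and a notice arriving at epoch 250, the loop saves at epochs 100 and 200 on schedule and once more at epoch 250 before stopping, so at most one epoch of work is lost instead of up to an entire interval.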

Score some easy (check)points at home

It couldn’t be easier to get started:

pip install metaflow-checkpoint

and follow examples in the documentation and a reference repository. If you have questions or ideas for improving @checkpoint, we'd love to hear from you on the Metaflow Slack!

Start building today

Join our office hours for a live demo! Whether you're curious about Outerbounds or have specific questions - nothing is off limits.

