
Beginner Computer Vision: Episode 4

Now that you have a stable workflow for training and evaluating models, it is time to iterate. In this lesson, you will see how to build a flow that runs a hyperparameter search. The search varies the number of convolutional filters in each of the network's two convolutional layers. Metaflow automatically parallelizes training, fitting one model per hyperparameter configuration, and versions the results of every training run. After running the flow, you will be able to fetch all hyperparameter values and metric scores in any Python environment.
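The search space in this lesson is small and hand-written: five pairs of filter counts, one count per convolutional layer. If you would rather generate such a grid than type it out, a minimal sketch follows. The `max_ratio` rule (keep a pair only if the second layer has at most four times as many filters as the first) is a hypothetical constraint, chosen only because it reproduces the five pairs used in this lesson's flow; the tutorial does not say how its pairs were picked.

```python
from itertools import combinations

# Filter counts that appear in this lesson's param_config.
FILTER_CHOICES = [16, 32, 64, 128]


def make_param_config(choices=FILTER_CHOICES, max_ratio=4):
    """Build foreach inputs for a two-layer CNN search.

    Keeps pairs (a, b) with a < b and b <= max_ratio * a. The ratio rule is a
    hypothetical way to prune the grid, not part of the original flow.
    """
    return [
        {"hidden_conv_layer_sizes": [a, b]}
        for a, b in combinations(sorted(choices), 2)
        if b <= max_ratio * a
    ]
```

Calling `make_param_config()` returns the same five dictionaries, in the same order, as the `param_config` list assigned in the flow's `start` step; raising `max_ratio` to 8 would add the `[16, 128]` pair.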

1. Write a Tuning Flow

This flow shows how you can tune the CNN model. The flow includes:

  • A start step that loads data.
  • A train step that trains and scores a model for each hyperparameter configuration.
    • Metaflow's foreach pattern runs these training tasks in parallel, both locally and in the cloud.
  • A gather_scores step that joins the results of the train branches, collects their scores in a results table, and shows each model's learning curves in a Metaflow card.
  • An end step that saves the best model.
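The start, train (foreach), gather_scores, end structure is a fan-out/fan-in pattern. As an analogy only, the sketch below reproduces that shape with Python's standard `concurrent.futures` module. Metaflow itself runs each train branch as its own task (separate processes locally, as the distinct pids in the run log show, or remote tasks in the cloud) and persists artifacts between steps, none of which this sketch does. The `train` function here is a hypothetical stand-in that fits no model.

```python
from concurrent.futures import ThreadPoolExecutor

PARAM_CONFIG = [
    {"hidden_conv_layer_sizes": [16, 32]},
    {"hidden_conv_layer_sizes": [16, 64]},
    {"hidden_conv_layer_sizes": [32, 64]},
    {"hidden_conv_layer_sizes": [32, 128]},
    {"hidden_conv_layer_sizes": [64, 128]},
]


def train(config):
    """Hypothetical stand-in for the train step: no model is fit.

    It returns the branch's layer sizes and their total filter count so the
    join below has something to collect.
    """
    sizes = config["hidden_conv_layer_sizes"]
    return {"sizes": sizes, "num_filters": sum(sizes)}


def gather_scores(branch_outputs):
    """Join: receives one output per branch and builds a column-wise table."""
    table = {"hidden conv layer sizes": [], "num_filters": []}
    for out in branch_outputs:
        table["hidden conv layer sizes"].append(",".join(str(s) for s in out["sizes"]))
        table["num_filters"].append(out["num_filters"])
    return table


def run(param_config=PARAM_CONFIG):
    # Fan out: one train call per config; Executor.map returns results in input order.
    with ThreadPoolExecutor() as pool:
        branch_outputs = list(pool.map(train, param_config))
    # Fan in: a single join sees every branch's output.
    return gather_scores(branch_outputs)
```

The column-wise dictionary mirrors the `results` dictionary that `gather_scores` builds in the real flow before turning it into a DataFrame.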

tuning_flow.py
from metaflow import FlowSpec, step, Flow, current, card
from metaflow.cards import Image, Table
from tensorflow import keras
from models import ModelOperations

class TuningFlow(FlowSpec, ModelOperations):

    # Model and training settings shared by every step.
    best_model_location = "best_tuned_model"
    num_pixels = 28 * 28
    kernel_initializer = 'normal'
    optimizer = 'adam'
    loss = 'categorical_crossentropy'
    metrics = [
        'accuracy',
        'precision at recall'
    ]
    input_shape = (28, 28, 1)
    kernel_size = (3, 3)
    pool_size = (2, 2)
    p_dropout = 0.5
    epochs = 5
    batch_size = 64
    verbose = 2

    @step
    def start(self):
        import numpy as np
        self.num_classes = 10
        # Load MNIST, scale pixels to [0, 1], and add a channel axis.
        ((x_train, y_train),
         (x_test, y_test)) = keras.datasets.mnist.load_data()
        x_train = x_train.astype("float32") / 255
        x_test = x_test.astype("float32") / 255
        self.x_train = np.expand_dims(x_train, -1)
        self.x_test = np.expand_dims(x_test, -1)
        # One-hot encode the labels.
        self.y_train = keras.utils.to_categorical(
            y_train, self.num_classes)
        self.y_test = keras.utils.to_categorical(
            y_test, self.num_classes)
        # Each entry becomes one parallel train task.
        self.param_config = [
            {"hidden_conv_layer_sizes": [16, 32]},
            {"hidden_conv_layer_sizes": [16, 64]},
            {"hidden_conv_layer_sizes": [32, 64]},
            {"hidden_conv_layer_sizes": [32, 128]},
            {"hidden_conv_layer_sizes": [64, 128]}
        ]
        self.next(self.train, foreach='param_config')

    @step
    def train(self):
        from neural_net_utils import plot_learning_curves
        # self.input is this branch's element of param_config.
        self.model = self.make_cnn(
            self.input['hidden_conv_layer_sizes'])
        self.history, self.scores = self.fit_and_score(
            self.x_train, self.x_test)
        self._name = 'CNN'
        self.plots = [
            Image.from_matplotlib(p) for p in
            plot_learning_curves(
                self.history,
                'Hidden Layers - ' + ', '.join([
                    str(i) for i in
                    self.input['hidden_conv_layer_sizes']
                ])
            )
        ]
        self.next(self.gather_scores)

    @card
    @step
    def gather_scores(self, models):
        import pandas as pd
        self.max_class = models[0].y_train
        results = {
            'hidden conv layer sizes': [],
            'model': [],
            'test loss': [],
            **{metric: [] for metric in self.metrics}
        }
        max_seen_acc = 0
        rows = []
        for model in models:
            results['model'].append(model._name)
            # scores holds the test loss, then each metric in self.metrics order.
            results['test loss'].append(model.scores[0])
            for i, metric in enumerate(self.metrics):
                results[metric].append(model.scores[i+1])
            results['hidden conv layer sizes'].append(
                ','.join([
                    str(i) for i in model.input[
                        'hidden_conv_layer_sizes'
                    ]
                ])
            )
            # A simple rule for determining the best model.
            # In production flows you need to think carefully
            # about how this kind of rule maps to your objectives.
            if model.scores[1] > max_seen_acc:
                self.best_model = model.model
                max_seen_acc = model.scores[1]
            rows.append(model.plots)

        # One card row of learning-curve plots per model.
        current.card.append(Table(rows))
        self.results = pd.DataFrame(results)
        self.next(self.end)

    @step
    def end(self):
        # Write the best model to disk at best_model_location.
        self.best_model.save(self.best_model_location)

if __name__ == '__main__':
    TuningFlow()
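Each configuration changes model size as well as accuracy. `models.py` is not shown in this lesson, so the sketch below assumes that `make_cnn` stacks one `Conv2D` layer (valid padding) and one `MaxPooling2D` layer per entry of `hidden_conv_layer_sizes`, followed by `Flatten`, `Dropout`, and a `Dense` softmax output, using the flow's `input_shape`, `kernel_size`, `pool_size`, and 10 classes. Under that assumption, it counts trainable parameters by hand.

```python
def count_params(hidden_conv_layer_sizes, input_shape=(28, 28, 1),
                 kernel_size=(3, 3), pool_size=(2, 2), num_classes=10):
    """Count trainable parameters of an ASSUMED make_cnn architecture.

    Assumption: per entry, Conv2D(filters, kernel_size, padding="valid") then
    MaxPooling2D(pool_size); then Flatten, Dropout, Dense(num_classes).
    """
    height, width, channels = input_shape
    total = 0
    for filters in hidden_conv_layer_sizes:
        # Conv2D: one kernel_h x kernel_w x channels weight block per filter, plus one bias each.
        total += kernel_size[0] * kernel_size[1] * channels * filters + filters
        # Valid padding shrinks each spatial dimension by (kernel - 1).
        height -= kernel_size[0] - 1
        width -= kernel_size[1] - 1
        # Max pooling with stride equal to the pool size rounds down.
        height //= pool_size[0]
        width //= pool_size[1]
        channels = filters
    # Dropout has no weights; Dense connects every flattened unit to each class, plus biases.
    total += height * width * channels * num_classes + num_classes
    return total
```

Under these assumptions, `[16, 32]` has 12,810 parameters, `[32, 64]` has 34,826, and `[64, 128]` has 106,506, which is consistent with the larger configurations taking longer per epoch in the run log. If you have a trained model at hand, Keras's `model.count_params()` gives the actual figure to compare against.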

2. Run the Tuning Flow

python tuning_flow.py run
     Workflow starting (run-id 1666721523161525):
[1666721523161525/start/1 (pid 53367)] Task is starting.
[1666721523161525/start/1 (pid 53367)] Foreach yields 5 child steps.
[1666721523161525/start/1 (pid 53367)] Task finished successfully.
[1666721523161525/train/2 (pid 53375)] Task is starting.
[1666721523161525/train/3 (pid 53376)] Task is starting.
[1666721523161525/train/4 (pid 53377)] Task is starting.
[1666721523161525/train/5 (pid 53378)] Task is starting.
[1666721523161525/train/6 (pid 53379)] Task is starting.
[1666721523161525/train/3 (pid 53376)] 742: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
[1666721523161525/train/6 (pid 53379)] 279: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
[1666721523161525/train/3 (pid 53376)] Epoch 1/5
[1666721523161525/train/4 (pid 53377)] 553: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
[1666721523161525/train/2 (pid 53375)] 816: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
[1666721523161525/train/6 (pid 53379)] Epoch 1/5
[1666721523161525/train/4 (pid 53377)] Epoch 1/5
[1666721523161525/train/2 (pid 53375)] Epoch 1/5
[1666721523161525/train/5 (pid 53378)] 924: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
[1666721523161525/train/5 (pid 53378)] Epoch 1/5
[1666721523161525/train/2 (pid 53375)] 938/938 - 14s - loss: 0.3722 - accuracy: 0.8858 - precision_at_recall: 0.6490 - val_loss: 0.0966 - val_accuracy: 0.9721 - val_precision_at_recall: 0.9819 - 14s/epoch - 15ms/step
[1666721523161525/train/3 (pid 53376)] 938/938 - 17s - loss: 0.2737 - accuracy: 0.9170 - precision_at_recall: 0.7856 - val_loss: 0.0725 - val_accuracy: 0.9771 - val_precision_at_recall: 0.9890 - 17s/epoch - 18ms/step
[1666721523161525/train/3 (pid 53376)] Epoch 2/5
[1666721523161525/train/4 (pid 53377)] 938/938 - 25s - loss: 0.2661 - accuracy: 0.9174 - precision_at_recall: 0.7964 - val_loss: 0.0650 - val_accuracy: 0.9802 - val_precision_at_recall: 0.9930 - 25s/epoch - 27ms/step
[1666721523161525/train/4 (pid 53377)] Epoch 2/5
[1666721523161525/train/2 (pid 53375)] Epoch 2/5
[1666721523161525/train/2 (pid 53375)] 938/938 - 13s - loss: 0.1234 - accuracy: 0.9624 - precision_at_recall: 0.9658 - val_loss: 0.0606 - val_accuracy: 0.9821 - val_precision_at_recall: 0.9934 - 13s/epoch - 14ms/step
[1666721523161525/train/3 (pid 53376)] 938/938 - 16s - loss: 0.0948 - accuracy: 0.9709 - precision_at_recall: 0.9809 - val_loss: 0.0491 - val_accuracy: 0.9853 - val_precision_at_recall: 0.9959 - 16s/epoch - 17ms/step
[1666721523161525/train/3 (pid 53376)] Epoch 3/5
[1666721523161525/train/5 (pid 53378)] 938/938 - 34s - loss: 0.2178 - accuracy: 0.9345 - precision_at_recall: 0.8700 - val_loss: 0.0588 - val_accuracy: 0.9814 - val_precision_at_recall: 0.9943 - 34s/epoch - 36ms/step
[1666721523161525/train/5 (pid 53378)] Epoch 2/5
[1666721523161525/train/2 (pid 53375)] Epoch 3/5
[1666721523161525/train/2 (pid 53375)] 938/938 - 12s - loss: 0.0940 - accuracy: 0.9701 - precision_at_recall: 0.9793 - val_loss: 0.0467 - val_accuracy: 0.9846 - val_precision_at_recall: 0.9961 - 12s/epoch - 13ms/step
[1666721523161525/train/4 (pid 53377)] 938/938 - 22s - loss: 0.0894 - accuracy: 0.9721 - precision_at_recall: 0.9825 - val_loss: 0.0437 - val_accuracy: 0.9854 - val_precision_at_recall: 0.9962 - 22s/epoch - 24ms/step
[1666721523161525/train/4 (pid 53377)] Epoch 3/5
[1666721523161525/train/3 (pid 53376)] 938/938 - 15s - loss: 0.0732 - accuracy: 0.9777 - precision_at_recall: 0.9893 - val_loss: 0.0382 - val_accuracy: 0.9873 - val_precision_at_recall: 0.9976 - 15s/epoch - 16ms/step
[1666721523161525/train/3 (pid 53376)] Epoch 4/5
[1666721523161525/train/6 (pid 53379)] 938/938 - 51s - loss: 0.1926 - accuracy: 0.9420 - precision_at_recall: 0.8998 - val_loss: 0.0510 - val_accuracy: 0.9833 - val_precision_at_recall: 0.9963 - 51s/epoch - 55ms/step
[1666721523161525/train/6 (pid 53379)] Epoch 2/5
[1666721523161525/train/2 (pid 53375)] Epoch 4/5
[1666721523161525/train/2 (pid 53375)] 938/938 - 12s - loss: 0.0836 - accuracy: 0.9743 - precision_at_recall: 0.9852 - val_loss: 0.0378 - val_accuracy: 0.9869 - val_precision_at_recall: 0.9974 - 12s/epoch - 13ms/step
[1666721523161525/train/3 (pid 53376)] 938/938 - 14s - loss: 0.0629 - accuracy: 0.9803 - precision_at_recall: 0.9920 - val_loss: 0.0353 - val_accuracy: 0.9877 - val_precision_at_recall: 0.9981 - 14s/epoch - 15ms/step
[1666721523161525/train/3 (pid 53376)] Epoch 5/5
[1666721523161525/train/2 (pid 53375)] Epoch 5/5
[1666721523161525/train/2 (pid 53375)] 938/938 - 11s - loss: 0.0716 - accuracy: 0.9780 - precision_at_recall: 0.9898 - val_loss: 0.0362 - val_accuracy: 0.9878 - val_precision_at_recall: 0.9977 - 11s/epoch - 12ms/step
[1666721523161525/train/5 (pid 53378)] 938/938 - 29s - loss: 0.0776 - accuracy: 0.9752 - precision_at_recall: 0.9874 - val_loss: 0.0408 - val_accuracy: 0.9861 - val_precision_at_recall: 0.9969 - 29s/epoch - 31ms/step
[1666721523161525/train/5 (pid 53378)] Epoch 3/5
[1666721523161525/train/2 (pid 53375)] WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.
[1666721523161525/train/2 (pid 53375)] WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.
[1666721523161525/train/2 (pid 53375)] Task finished successfully.
[1666721523161525/train/4 (pid 53377)] 938/938 - 21s - loss: 0.0686 - accuracy: 0.9790 - precision_at_recall: 0.9907 - val_loss: 0.0375 - val_accuracy: 0.9875 - val_precision_at_recall: 0.9977 - 21s/epoch - 22ms/step
[1666721523161525/train/4 (pid 53377)] Epoch 4/5
[1666721523161525/train/3 (pid 53376)] 938/938 - 12s - loss: 0.0556 - accuracy: 0.9833 - precision_at_recall: 0.9940 - val_loss: 0.0313 - val_accuracy: 0.9896 - val_precision_at_recall: 0.9979 - 12s/epoch - 13ms/step
[1666721523161525/train/3 (pid 53376)] WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.
[1666721523161525/train/3 (pid 53376)] WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.
[1666721523161525/train/3 (pid 53376)] Task finished successfully.
[1666721523161525/train/4 (pid 53377)] 938/938 - 16s - loss: 0.0577 - accuracy: 0.9815 - precision_at_recall: 0.9932 - val_loss: 0.0332 - val_accuracy: 0.9889 - val_precision_at_recall: 0.9980 - 16s/epoch - 17ms/step
[1666721523161525/train/4 (pid 53377)] Epoch 5/5
[1666721523161525/train/5 (pid 53378)] 938/938 - 21s - loss: 0.0593 - accuracy: 0.9814 - precision_at_recall: 0.9933 - val_loss: 0.0362 - val_accuracy: 0.9880 - val_precision_at_recall: 0.9970 - 21s/epoch - 23ms/step
[1666721523161525/train/5 (pid 53378)] Epoch 4/5
[1666721523161525/train/6 (pid 53379)] 938/938 - 36s - loss: 0.0682 - accuracy: 0.9793 - precision_at_recall: 0.9908 - val_loss: 0.0375 - val_accuracy: 0.9882 - val_precision_at_recall: 0.9979 - 36s/epoch - 38ms/step
[1666721523161525/train/6 (pid 53379)] Epoch 3/5
[1666721523161525/train/4 (pid 53377)] 938/938 - 18s - loss: 0.0508 - accuracy: 0.9840 - precision_at_recall: 0.9952 - val_loss: 0.0279 - val_accuracy: 0.9908 - val_precision_at_recall: 0.9988 - 18s/epoch - 19ms/step
[1666721523161525/train/4 (pid 53377)] WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.
[1666721523161525/train/5 (pid 53378)] 938/938 - 20s - loss: 0.0496 - accuracy: 0.9843 - precision_at_recall: 0.9952 - val_loss: 0.0324 - val_accuracy: 0.9893 - val_precision_at_recall: 0.9984 - 20s/epoch - 21ms/step
[1666721523161525/train/5 (pid 53378)] Epoch 5/5
[1666721523161525/train/4 (pid 53377)] WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.
[1666721523161525/train/4 (pid 53377)] Task finished successfully.
[1666721523161525/train/6 (pid 53379)] 938/938 - 25s - loss: 0.0537 - accuracy: 0.9834 - precision_at_recall: 0.9944 - val_loss: 0.0306 - val_accuracy: 0.9900 - val_precision_at_recall: 0.9983 - 25s/epoch - 27ms/step
[1666721523161525/train/6 (pid 53379)] Epoch 4/5
[1666721523161525/train/5 (pid 53378)] 938/938 - 15s - loss: 0.0411 - accuracy: 0.9866 - precision_at_recall: 0.9969 - val_loss: 0.0310 - val_accuracy: 0.9899 - val_precision_at_recall: 0.9980 - 15s/epoch - 16ms/step
[1666721523161525/train/5 (pid 53378)] WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.
[1666721523161525/train/5 (pid 53378)] WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.
[1666721523161525/train/5 (pid 53378)] Task finished successfully.
[1666721523161525/train/6 (pid 53379)] 938/938 - 17s - loss: 0.0435 - accuracy: 0.9862 - precision_at_recall: 0.9963 - val_loss: 0.0300 - val_accuracy: 0.9906 - val_precision_at_recall: 0.9984 - 17s/epoch - 18ms/step
[1666721523161525/train/6 (pid 53379)] Epoch 5/5
[1666721523161525/train/6 (pid 53379)] 938/938 - 14s - loss: 0.0373 - accuracy: 0.9880 - precision_at_recall: 0.9974 - val_loss: 0.0270 - val_accuracy: 0.9910 - val_precision_at_recall: 0.9993 - 14s/epoch - 15ms/step
[1666721523161525/train/6 (pid 53379)] WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.
[1666721523161525/train/6 (pid 53379)] WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.
[1666721523161525/train/6 (pid 53379)] Task finished successfully.
[1666721523161525/gather_scores/7 (pid 53481)] Task is starting.
[1666721523161525/gather_scores/7 (pid 53481)] WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.
[1666721523161525/gather_scores/7 (pid 53481)] Task finished successfully.
[1666721523161525/end/8 (pid 53487)] Task is starting.
[1666721523161525/end/8 (pid 53487)] WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.
[1666721523161525/end/8 (pid 53487)] WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.
[1666721523161525/end/8 (pid 53487)] Task finished successfully.
Done!
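With `verbose = 2`, Keras prints one summary line per epoch, and Metaflow prefixes every line with the run id, step name, task id, and pid. If you save this output (for example by redirecting stdout to a file), a short sketch like the following can pull the last epoch's `val_accuracy` for each task. It assumes the line format shown above. The log does not print which hyperparameter configuration each task received, so look that up separately rather than inferring it from the task ids.

```python
import re

# Matches Metaflow-prefixed Keras epoch summaries such as
# "[1666721523161525/train/2 (pid 53375)] 938/938 - 11s - loss: ... - val_accuracy: 0.9878 - ..."
EPOCH_LINE = re.compile(
    r"^\[\d+/(?P<task>\w+/\d+) \(pid \d+\)\] "  # Metaflow prefix: [run/step/task (pid N)]
    r"\d+/\d+ - .*?- val_accuracy: (?P<val_acc>\d+\.\d+)"  # Keras epoch summary
)


def last_val_accuracy(log_lines):
    """Return {"step/task_id": val_accuracy from that task's last epoch line}."""
    latest = {}
    for line in log_lines:
        match = EPOCH_LINE.match(line)
        if match:
            # Later epochs overwrite earlier ones, so the final epoch wins.
            latest[match.group("task")] = float(match.group("val_acc"))
    return latest
```

Applied to the log above, the final values are 0.9878 (train/2), 0.9896 (train/3), 0.9908 (train/4), 0.9899 (train/5), and 0.9910 (train/6). Keep in mind that gather_scores picks the best model from `scores[1]` returned by `fit_and_score`, which need not be the same number as the last epoch's `val_accuracy`.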

3. Visualize Results

python tuning_flow.py card view gather_scores
    Metaflow 2.7.12 executing TuningFlow for user:eddie
Resolving card: TuningFlow/1666721523161525/gather_scores/7
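The card is one view of the results; the same numbers are also stored in the `results` artifact, the pandas DataFrame built in `gather_scores`. A minimal sketch of reading it from any Python environment with the Metaflow Client API (`Flow(...).latest_successful_run.data`) and ranking the configurations follows. The tie-break on test loss is an illustrative choice, not the flow's own rule: `gather_scores` keeps the first model that reaches the highest accuracy.

```python
def fetch_results(flow_name="TuningFlow"):
    """Return the `results` DataFrame of the latest successful run as a list of dicts.

    Requires the metaflow package and at least one successful run of the flow.
    """
    from metaflow import Flow  # imported here so rank_configs works without Metaflow

    run = Flow(flow_name).latest_successful_run
    return run.data.results.to_dict(orient="records")


def rank_configs(records, metric="accuracy", loss_key="test loss"):
    """Sort rows by `metric`, best first; ties go to the lower test loss."""
    return sorted(records, key=lambda row: (-row[metric], row[loss_key]))
```

For example, `rank_configs(fetch_results())[0]["hidden conv layer sizes"]` gives the best configuration under this rule, and passing `metric="precision at recall"` ranks by the other tracked metric instead.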

In this lesson, you saw how to extend your model training flows to tune hyperparameters in parallel. Whether you are building an ML platform or a workflow targeting a single application, it is important to consider your exploration budget and how you will improve models through processes like hyperparameter tuning. Metaflow can also be combined with more sophisticated tuning algorithms, for example those provided by Optuna. In the next lesson, you will interpret the results of all the models you have trained thus far. See you there!