
Keras Neural Network Flow

In this episode, you will use Keras to build a neural network in a Metaflow flow.

1. Write a Neural Network Flow

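The flow has four steps:

  • The start step loads the MNIST handwritten-digit images from Keras and preprocesses them.
  • The build_model step builds and compiles a Keras convolutional model.
  • The train step fits the neural net to the training data.
  • The end step prints a completion message.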
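The preprocessing in the start step can be tried in isolation with plain NumPy. This is a minimal sketch under stated assumptions: it uses random synthetic images instead of MNIST, and `one_hot` and `preprocess` are hypothetical helpers, with `one_hot` standing in for `keras.utils.to_categorical` on integer labels.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Hypothetical stand-in for keras.utils.to_categorical on integer labels."""
    encoded = np.zeros((labels.shape[0], num_classes), dtype="float32")
    encoded[np.arange(labels.shape[0]), labels] = 1.0
    return encoded

def preprocess(images, labels, num_classes=10):
    """Mirror the start step: scale pixels, add a channel axis, one-hot the labels."""
    x = images.astype("float32") / 255   # uint8 [0, 255] -> float32 [0, 1]
    x = np.expand_dims(x, -1)            # (N, 28, 28) -> (N, 28, 28, 1)
    y = one_hot(labels, num_classes)     # (N,) -> (N, num_classes)
    return x, y

# Synthetic stand-in for MNIST: 4 random 28x28 "images" with labels in 0..9.
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(4, 28, 28)).astype(np.uint8)
fake_labels = np.array([3, 0, 9, 3])
x, y = preprocess(fake_images, fake_labels)
print(x.shape, y.shape)  # (4, 28, 28, 1) (4, 10)
```

The channel axis matters because the model's `keras.Input(shape=(28,28,1))` expects a single grayscale channel per image.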

neural_net_flow.py
from metaflow import FlowSpec, step, Parameter


class NeuralNetFlow(FlowSpec):

    # Number of training epochs; override it with --e on the command line.
    epochs = Parameter('e', default=10)

    @step
    def start(self):
        import numpy as np
        from tensorflow import keras
        self.num_classes = 10
        # Load MNIST: 28x28 grayscale digit images with integer labels 0-9.
        ((x_train, y_train),
         (x_test, y_test)) = keras.datasets.mnist.load_data()
        # Scale pixel values from [0, 255] to [0, 1].
        x_train = x_train.astype("float32") / 255
        x_test = x_test.astype("float32") / 255
        # Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
        self.x_train = np.expand_dims(x_train, -1)
        self.x_test = np.expand_dims(x_test, -1)
        # One-hot encode the labels for categorical cross-entropy.
        self.y_train = keras.utils.to_categorical(
            y_train, self.num_classes)
        self.y_test = keras.utils.to_categorical(
            y_test, self.num_classes)
        self.next(self.build_model)

    @step
    def build_model(self):
        from tensorflow import keras
        from tensorflow.keras import layers  # pylint: disable=import-error
        # A small convolutional network: two conv/pool blocks followed by
        # a dropout-regularized softmax classifier.
        self.model = keras.Sequential([
            keras.Input(shape=(28, 28, 1)),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(self.num_classes, activation="softmax"),
        ])
        self.model.compile(loss="categorical_crossentropy",
                           optimizer="adam", metrics=["accuracy"])
        self.next(self.train)

    @step
    def train(self):
        self.batch_size = 128
        # Hold out the last 10% of the training data so Keras reports
        # val_loss and val_accuracy after each epoch.
        self.model.fit(
            self.x_train, self.y_train,
            batch_size=self.batch_size,
            epochs=self.epochs, validation_split=0.1
        )
        self.next(self.end)

    @step
    def end(self):
        print("NeuralNetFlow is all done.")


if __name__ == "__main__":
    NeuralNetFlow()
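The size of the network built in build_model can be checked by hand. This back-of-the-envelope sketch assumes the Keras defaults the flow relies on (`padding="valid"` and stride 1 for `Conv2D`, stride equal to the pool size for `MaxPooling2D`); the helper names `conv_out`, `pool_out`, and `conv_params` are hypothetical.

```python
def conv_out(size, kernel):
    """Spatial size after a stride-1 'valid' convolution (Keras Conv2D default)."""
    return size - kernel + 1

def pool_out(size, pool):
    """Spatial size after MaxPooling2D with stride = pool size and 'valid' padding."""
    return size // pool

def conv_params(kernel, in_ch, out_ch):
    """kernel*kernel*in_ch weights per filter, plus one bias per filter."""
    return kernel * kernel * in_ch * out_ch + out_ch

size, channels, total = 28, 1, 0
for filters in (32, 64):          # the two Conv2D + MaxPooling2D blocks
    total += conv_params(3, channels, filters)
    size = pool_out(conv_out(size, 3), 2)   # 28 -> 26 -> 13, then 13 -> 11 -> 5
    channels = filters
flat = size * size * channels     # Flatten; Dropout adds no parameters
total += flat * 10 + 10           # Dense(num_classes=10): weights + biases
print(size, flat, total)  # 5 1600 34826
```

Calling `self.model.summary()` at the end of build_model should report the same total parameter count.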

2. Run the Flow

python neural_net_flow.py run
Workflow starting (run-id 1666720917686061):
[1666720917686061/start/1 (pid 52844)] Task is starting.
[1666720917686061/start/1 (pid 52844)] Task finished successfully.
[1666720917686061/build_model/2 (pid 52862)] Task is starting.
[1666720917686061/build_model/2 (pid 52862)] Task finished successfully.
[1666720917686061/train/3 (pid 52873)] Task is starting.
422/422 [==============================] - 8s 18ms/step - loss: 0.3670 - accuracy: 0.8887 - val_loss: 0.0776 - val_accuracy: 0.9800
422/422 [==============================] - 8s 19ms/step - loss: 0.1067 - accuracy: 0.9674 - val_loss: 0.0540 - val_accuracy: 0.9855
422/422 [==============================] - 8s 18ms/step - loss: 0.0811 - accuracy: 0.9757 - val_loss: 0.0464 - val_accuracy: 0.9867
422/422 [==============================] - 8s 18ms/step - loss: 0.0660 - accuracy: 0.9789 - val_loss: 0.0450 - val_accuracy: 0.9872
422/422 [==============================] - 8s 18ms/step - loss: 0.0585 - accuracy: 0.9820 - val_loss: 0.0363 - val_accuracy: 0.9895
422/422 [==============================] - 7s 18ms/step - loss: 0.0530 - accuracy: 0.9832 - val_loss: 0.0360 - val_accuracy: 0.9910
422/422 [==============================] - 8s 18ms/step - loss: 0.0468 - accuracy: 0.9853 - val_loss: 0.0337 - val_accuracy: 0.9918
422/422 [==============================] - 8s 19ms/step - loss: 0.0433 - accuracy: 0.9861 - val_loss: 0.0324 - val_accuracy: 0.9910
422/422 [==============================] - 9s 21ms/step - loss: 0.0417 - accuracy: 0.9868 - val_loss: 0.0353 - val_accuracy: 0.9907
422/422 [==============================] - 8s 20ms/step - loss: 0.0396 - accuracy: 0.9877 - val_loss: 0.0335 - val_accuracy: 0.9910
[1666720917686061/train/3 (pid 52873)] Task finished successfully.
[1666720917686061/end/4 (pid 52933)] Task is starting.
[1666720917686061/end/4 (pid 52933)] NeuralNetFlow is all done.
[1666720917686061/end/4 (pid 52933)] Task finished successfully.
Done!
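The `422/422` on each epoch line follows from the train step's settings. A minimal sketch, assuming MNIST's 60,000 training images and that `validation_split=0.1` holds out the last 10% of the arrays (as the Keras `Model.fit` documentation describes); `steps_per_epoch` is a hypothetical helper.

```python
import math

def steps_per_epoch(num_samples, batch_size, validation_split):
    """Training batches per epoch after Keras holds out a validation fraction."""
    num_train = math.floor(num_samples * (1.0 - validation_split))
    num_val = num_samples - num_train          # held-out validation examples
    # The last batch may be partial, so round up.
    return num_train, num_val, math.ceil(num_train / batch_size)

print(steps_per_epoch(60_000, 128, 0.1))  # (54000, 6000, 422)
```

54,000 / 128 = 421.875, so the final batch of each epoch holds the remaining 112 examples.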

In this episode, you saw how to build and train a neural network in a Metaflow flow and monitor its validation accuracy after each epoch.

In the next episode, you will see how to use Metaflow's cards feature to add data visualization to this flow. See you there!