A small library for managing deep learning models, hyperparameters and datasets
A small library for managing deep learning models, hyperparameters and datasets, designed to make training deep learning models easy and reproducible.
Zookeeper allows you to build command-line interfaces for training deep learning models with very little boilerplate, using Click and TensorFlow Datasets. It helps you structure your machine learning projects in a framework-agnostic and effective way. Zookeeper is heavily inspired by Tensor2Tensor and Fairseq, but it is designed to be used as a library, which makes it lightweight and very flexible. Currently Zookeeper is limited to image classification tasks, but we are working on making it useful for other tasks as well.
pip install zookeeper
pip install colorama  # optional, for colored console output
Zookeeper keeps track of data preprocessing, models and hyperparameters so that you can reference them by name from the command line.
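One way to picture this is a name-keyed registry that decorators fill in. The following is a minimal plain-Python sketch of that idea under our own assumptions, not Zookeeper's actual implementation; the dictionaries MODELS and HPARAMS and the helpers register_model, register_hparams and lookup are hypothetical names chosen for illustration.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registries keyed by name (not Zookeeper's internals).
MODELS: Dict[str, Callable] = {}
HPARAMS: Dict[str, Dict[str, type]] = {}


def register_model(fn: Callable) -> Callable:
    """Store a model-building function under its own function name."""
    MODELS[fn.__name__] = fn
    return fn


def register_hparams(model_fn: Callable) -> Callable[[type], type]:
    """Return a class decorator that files a hyperparameter set under a model."""
    def decorator(cls: type) -> type:
        HPARAMS.setdefault(model_fn.__name__, {})[cls.__name__] = cls
        return cls
    return decorator


def lookup(model_name: str, hparams_set: str) -> Tuple[Callable, type]:
    """Resolve names typed on the command line to the registered objects."""
    return MODELS[model_name], HPARAMS[model_name][hparams_set]


@register_model
def cnn(hp, dataset):
    return f"cnn with {hp.filters} filters"


@register_hparams(cnn)
class basic:
    filters = 64


model_fn, hparams_cls = lookup("cnn", "basic")
print(model_fn(hparams_cls(), dataset=None))  # cnn with 64 filters
```

Because decorators run when a module is imported, a registry like this only knows about the modules that have actually been imported.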
Datasets and Preprocessing
TensorFlow Datasets provides many popular datasets that can be downloaded automatically.
In the following we will use MNIST and define a preprocessing function called default that scales the image pixel values to the range [0, 1]:
import tensorflow as tf
from zookeeper import cli, build_train, HParams, registry


@registry.register_preprocess("mnist")
def default(image, training=False):
    # Convert 8-bit pixel values to float32 and scale them to [0, 1]
    return tf.cast(image, dtype=tf.float32) / 255
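The division by 255 is what maps 8-bit pixel intensities (0 to 255) onto [0, 1]. As a quick plain-Python check of that arithmetic without TensorFlow (scale_pixels is a hypothetical helper, not part of Zookeeper):

```python
def scale_pixels(pixels):
    """Map 8-bit intensities in 0..255 to floats in [0.0, 1.0] by dividing by 255."""
    return [float(p) / 255 for p in pixels]


scaled = scale_pixels([0, 128, 255])
print(scaled[0], scaled[-1])  # 0.0 1.0
```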
Next we will register a model called cnn. We will use the Keras API for this:
@registry.register_model
def cnn(hp, dataset):
    return tf.keras.models.Sequential(
        [
            tf.keras.layers.Conv2D(
                hp.filters,
                (3, 3),
                activation=hp.activation,
                input_shape=dataset.input_shape,
            ),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(hp.filters, (3, 3), activation=hp.activation),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(hp.filters, (3, 3), activation=hp.activation),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(hp.filters, activation=hp.activation),
            tf.keras.layers.Dense(dataset.num_classes, activation="softmax"),
        ]
    )
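For MNIST, whose images are 28 × 28 × 1 with 10 classes, the layer sizes of this network can be worked out by hand. The sketch below assumes Keras defaults (padding="valid" and stride 1 for Conv2D, stride equal to the pool size for MaxPooling2D), a square input and filters = 64; conv_out, pool_out and cnn_summary are hypothetical helpers, not part of Zookeeper or Keras.

```python
def conv_out(size: int, kernel: int = 3) -> int:
    # Keras Conv2D default: padding="valid", stride 1
    return size - kernel + 1


def pool_out(size: int, pool: int = 2) -> int:
    # Keras MaxPooling2D default: stride equals pool size, padding="valid"
    return size // pool


def cnn_summary(height=28, channels=1, filters=64, num_classes=10):
    """Return (final feature-map side, flattened size, trainable parameters)."""
    params = 3 * 3 * channels * filters + filters    # Conv2D #1 (weights + biases)
    size = pool_out(conv_out(height))                # 28 -> 26 -> 13
    params += 3 * 3 * filters * filters + filters    # Conv2D #2
    size = pool_out(conv_out(size))                  # 13 -> 11 -> 5
    params += 3 * 3 * filters * filters + filters    # Conv2D #3
    size = conv_out(size)                            # 5 -> 3
    flat = size * size * filters                     # Flatten
    params += flat * filters + filters               # Dense(filters)
    params += filters * num_classes + num_classes    # Dense(num_classes)
    return size, flat, params


print(cnn_summary())  # (3, 576, 112074)
```

Under these assumptions the final convolution leaves a 3 × 3 × 64 feature map, so the first Dense layer sees 576 inputs.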
For each model we can register one or more hyperparameter sets that will be passed to the model function when it is called:
@registry.register_hparams(cnn)
class basic(HParams):
    activation = "relu"
    batch_size = 32
    filters = 64  # a single integer: Conv2D and Dense expect an int, not a list
    learning_rate = 1e-3

    @property
    def optimizer(self):
        return tf.keras.optimizers.Adam(self.learning_rate)
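Defining optimizer as a property means it is rebuilt from whatever learning_rate is in effect, so a subclass or an instance-level override of learning_rate changes the optimizer too. A plain-Python sketch of that behaviour, with a tuple standing in for the Keras optimizer; BasicHParams and FastHParams are hypothetical classes that do not use Zookeeper's HParams.

```python
class BasicHParams:
    activation = "relu"
    batch_size = 32
    filters = 64
    learning_rate = 1e-3

    @property
    def optimizer(self):
        # Stand-in for tf.keras.optimizers.Adam(self.learning_rate)
        return ("Adam", self.learning_rate)


class FastHParams(BasicHParams):
    # Hypothetical second set: inherits everything except the learning rate
    learning_rate = 1e-2


hp = BasicHParams()
hp.learning_rate = 5e-4  # instance-level override, e.g. a value parsed from the command line
print(hp.optimizer)             # ('Adam', 0.0005)
print(FastHParams().optimizer)  # ('Adam', 0.01)
```

A plain class attribute such as optimizer = Adam(learning_rate) would instead be evaluated once, when the class body runs, and would ignore later overrides.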
To train the models registered above, we need to write our own training function. Zookeeper will then tie everything together:
@cli.command()
@build_train
def train(build_model, dataset, hparams, output_dir, epochs):
    model = build_model(hparams, dataset)
    model.compile(
        optimizer=hparams.optimizer,
        loss="categorical_crossentropy",
        metrics=["categorical_accuracy", "top_k_categorical_accuracy"],
    )

    model.fit(
        dataset.train_data(hparams.batch_size),
        epochs=epochs,
        steps_per_epoch=dataset.train_examples // hparams.batch_size,
        validation_data=dataset.validation_data(hparams.batch_size),
        validation_steps=dataset.validation_examples // hparams.batch_size,
    )
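The integer divisions passed to model.fit decide how many batches count as one epoch. A quick check of that arithmetic, assuming MNIST's standard split sizes in TensorFlow Datasets (60,000 training and 10,000 test examples) and assuming the test split serves as validation data; the steps helper is hypothetical.

```python
MNIST_TRAIN = 60_000
MNIST_VALIDATION = 10_000  # assumption: the MNIST test split is used for validation


def steps(num_examples: int, batch_size: int) -> int:
    # Floor division, as in steps_per_epoch and validation_steps above
    return num_examples // batch_size


for batch_size in (32, 64):
    print(batch_size, steps(MNIST_TRAIN, batch_size), steps(MNIST_VALIDATION, batch_size))
# 32 1875 312
# 64 937 156
```

With batch_size=64 an epoch therefore covers 937 × 64 = 59,968 training examples rather than all 60,000.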
This registers a Click command called train that can be executed from the command line.
Command Line Interface
To make the file we just created executable, we add the following lines at the bottom:
if __name__ == "__main__":
    cli()
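The cli() call dispatches to whichever subcommand is named first on the command line. The same command-group pattern can be sketched with the standard library's argparse as a stand-in for Click; the options mirror the train invocation used in this README, but the parser itself is hypothetical and is not how Zookeeper builds its commands.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="train.py")
    # One sub-parser per command, similar to Click commands in a group
    commands = parser.add_subparsers(dest="command", required=True)
    train = commands.add_parser("train", help="train a registered model")
    train.add_argument("model")                     # e.g. cnn
    train.add_argument("--dataset", required=True)  # e.g. mnist
    train.add_argument("--epochs", type=int, default=1)
    train.add_argument("--hparams-set", default="basic")
    train.add_argument("--hparams", default="")
    return parser


args = build_parser().parse_args(
    ["train", "cnn", "--dataset", "mnist", "--epochs", "10",
     "--hparams-set", "basic", "--hparams", "batch_size=64"]
)
print(args.command, args.model, args.dataset, args.epochs, args.hparams_set)
# train cnn mnist 10 basic
```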
If you register your models in separate files, make sure to import them before calling cli so that Zookeeper can register them properly. To install your CLI as an executable command, check out the setuptools integration of Click.
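A minimal setup.py along the lines of Click's setuptools integration might look like this. The distribution name my-project, the module my_project.train and the command name my-train are hypothetical placeholders; the entry point should point at the cli object in your own training module.

```python
from setuptools import find_packages, setup

setup(
    name="my-project",  # hypothetical distribution name
    version="0.1.0",
    packages=find_packages(),
    install_requires=["zookeeper"],
    entry_points={
        "console_scripts": [
            # Installs a `my-train` executable that calls my_project.train:cli
            "my-train = my_project.train:cli",
        ],
    },
)
```

After installing the package with pip install ., running my-train --help should list the same commands as python examples/train.py --help.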
Zookeeper already ships with the prepare, plot and tensorboard commands, and now also includes the train command we created above:
python examples/train.py --help
Usage: train.py [OPTIONS] COMMAND [ARGS]...

Options:
  --help  Show this message and exit.

Commands:
  plot
  prepare
  tensorboard
  train
To train the model we just registered, run:
python examples/train.py train cnn --dataset mnist --epochs 10 --hparams-set basic --hparams batch_size=64
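The --hparams option supplies key=value overrides on top of the chosen hyperparameter set. As a rough sketch of how such a string could be turned into typed values, here is a hypothetical parser that assumes comma-separated pairs and Python-literal values; Zookeeper's own parsing rules may differ, and parse_overrides, apply_overrides and Basic are illustrative names only.

```python
import ast


def parse_overrides(spec: str) -> dict:
    """Parse 'key=value,key2=value2' into a dict of Python values."""
    overrides = {}
    for item in spec.split(","):
        if not item.strip():
            continue  # tolerate empty input and trailing commas
        key, _, raw = item.partition("=")
        try:
            value = ast.literal_eval(raw.strip())  # "64" -> 64, "5e-4" -> 0.0005
        except (ValueError, SyntaxError):
            value = raw.strip()  # bare words such as relu stay strings
        overrides[key.strip()] = value
    return overrides


def apply_overrides(hparams, overrides: dict):
    """Set each override on the hparams object, rejecting unknown names."""
    for key, value in overrides.items():
        if not hasattr(hparams, key):
            raise AttributeError(f"unknown hyperparameter: {key}")
        setattr(hparams, key, value)
    return hparams


class Basic:  # hypothetical stand-in for the basic hyperparameter set
    activation = "relu"
    batch_size = 32
    learning_rate = 1e-3


hp = apply_overrides(Basic(), parse_overrides("batch_size=64,learning_rate=5e-4"))
print(hp.batch_size, hp.learning_rate, hp.activation)  # 64 0.0005 relu
```

Rejecting unknown keys catches typos such as batchsize=64 instead of silently adding a new attribute.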