Zookeeper
A small library for managing deep learning models, hyper parameters and datasets designed to make training deep learning models easy and reproducible.
Getting Started
Zookeeper allows you to build command line interfaces for training deep learning models with very little boilerplate, using Click and TensorFlow Datasets. It helps you structure your machine learning projects in a framework-agnostic and effective way. Zookeeper is heavily inspired by Tensor2Tensor and Fairseq but is designed to be used as a library, which makes it lightweight and very flexible. Currently, Zookeeper is limited to image classification tasks, but we are working on making it useful for other tasks as well.
Installation
```shell
pip install zookeeper
pip install colorama  # optional, for colored console output
```
Registry
Zookeeper keeps track of data preprocessing functions, models and hyperparameter sets so that you can reference them by name from the command line.
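Conceptually, a registry like this can be little more than dictionaries keyed by name. The following is a minimal sketch, not Zookeeper's actual implementation; the `Registry` class and its `lookup` method are hypothetical and only loosely mirror the decorators used in the sections below.

```python
from typing import Callable, Dict


class Registry:
    """Hypothetical name-based registry (not Zookeeper's implementation)."""

    def __init__(self) -> None:
        self.models: Dict[str, Callable] = {}
        # Hyperparameter sets are grouped by the name of the model they belong to.
        self.hparams: Dict[str, Dict[str, type]] = {}

    def register_model(self, fn: Callable) -> Callable:
        # Use the function's own name as the key, e.g. "cnn".
        self.models[fn.__name__] = fn
        self.hparams.setdefault(fn.__name__, {})
        return fn

    def register_hparams(self, model: Callable) -> Callable[[type], type]:
        def decorator(cls: type) -> type:
            self.hparams.setdefault(model.__name__, {})[cls.__name__] = cls
            return cls

        return decorator

    def lookup(self, model_name: str, hparams_set: str):
        # Resolve names as they would be given on the command line.
        return self.models[model_name], self.hparams[model_name][hparams_set]()


registry = Registry()


@registry.register_model
def cnn(hp):
    return f"cnn with {hp.filters} filters"


@registry.register_hparams(cnn)
class basic:
    filters = 64


build_model, hp = registry.lookup("cnn", "basic")
print(build_model(hp))  # cnn with 64 filters
```

In this sketch an unknown name simply raises `KeyError`; a real command line tool would turn that into a friendlier error message.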
Datasets and Preprocessing
TensorFlow Datasets provides many popular datasets that can be downloaded automatically.
In the following, we will use MNIST and define a default preprocessing function for the images that scales the pixel values to [0, 1]:
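```python
import tensorflow as tf
from zookeeper import cli, build_train, HParams, registry


# Register `default` as a preprocessing function for the "mnist" dataset.
@registry.register_preprocess("mnist")
def default(image, training=False):
    # Convert integer pixel values in [0, 255] to float32 values in [0, 1].
    return tf.cast(image, dtype=tf.float32) / 255
```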
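Dividing by 255 maps the 8-bit range [0, 255] linearly onto [0, 1]. A framework-free sketch of the same arithmetic in plain Python; `scale_pixels` is a hypothetical helper that only mirrors what `default` does with TensorFlow tensors:

```python
def scale_pixels(pixels):
    """Map integer pixel values in [0, 255] to floats in [0, 1]."""
    return [p / 255 for p in pixels]


print(scale_pixels([0, 51, 255]))  # [0.0, 0.2, 1.0]
```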
Models
Next, we will register a model called `cnn`, using the Keras API:
```python
@registry.register_model
def cnn(hp, dataset):
    return tf.keras.models.Sequential(
        [
            tf.keras.layers.Conv2D(
                hp.filters[0],
                (3, 3),
                activation=hp.activation,
                input_shape=dataset.input_shape,
            ),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(hp.filters[1], (3, 3), activation=hp.activation),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(hp.filters[2], (3, 3), activation=hp.activation),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(hp.filters[3], activation=hp.activation),
            tf.keras.layers.Dense(dataset.num_classes, activation="softmax"),
        ]
    )
```
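To see what the `Flatten` layer hands to the first `Dense` layer, the spatial size can be traced by hand. The sketch below assumes MNIST's 28x28x1 input, `filters = [64, 64, 64, 64]` as in the `basic` set defined next, and the Keras defaults of `padding="valid"` for `Conv2D` and for `MaxPooling2D` (whose stride defaults to the pool size); the helper names are hypothetical.

```python
def conv_valid(size, kernel=3):
    # A 'valid' convolution with stride 1 shrinks each side by kernel - 1.
    return size - kernel + 1


def pool(size, pool_size=2):
    # 'valid' max pooling with stride == pool_size floors the result.
    return (size - pool_size) // pool_size + 1


filters = [64, 64, 64, 64]
num_classes = 10

size = 28  # MNIST images are 28x28 with a single channel
size = pool(conv_valid(size))  # 26 -> 13
size = pool(conv_valid(size))  # 11 -> 5
size = conv_valid(size)        # 3
flat = size * size * filters[2]
print(size, flat)  # 3 576

# Weights plus biases per layer:
# conv = (3 * 3 * in_channels + 1) * out_channels, dense = (inputs + 1) * units
params = [
    (3 * 3 * 1 + 1) * filters[0],
    (3 * 3 * filters[0] + 1) * filters[1],
    (3 * 3 * filters[1] + 1) * filters[2],
    (flat + 1) * filters[3],
    (filters[3] + 1) * num_classes,
]
print(sum(params))  # 112074
```

Calling `model.summary()` on the Keras model should report the same shapes; if the input size or kernel sizes change, this arithmetic is a quick way to check that the last convolution still has a positive spatial size.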
Hyperparameters
For each model, we can register one or more hyperparameter sets that will be passed to the model function when it is called:
```python
@registry.register_hparams(cnn)
class basic(HParams):
    activation = "relu"
    batch_size = 32
    filters = [64, 64, 64, 64]
    learning_rate = 1e-3

    @property
    def optimizer(self):
        return tf.keras.optimizers.Adam(self.learning_rate)
```
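Defining `optimizer` as a property means it is built from whatever `learning_rate` holds at the moment it is accessed, so changing `learning_rate` also changes the optimizer. A plain-Python sketch of that behaviour, assuming an override simply replaces the attribute value: `BasicHParams` and `make_optimizer` are hypothetical stand-ins for the `HParams` subclass and `tf.keras.optimizers.Adam`, and `setattr` is used only for illustration, not as Zookeeper's override mechanism.

```python
def make_optimizer(learning_rate):
    # Hypothetical stand-in for tf.keras.optimizers.Adam(learning_rate).
    return {"name": "adam", "learning_rate": learning_rate}


class BasicHParams:
    activation = "relu"
    batch_size = 32
    filters = [64, 64, 64, 64]
    learning_rate = 1e-3

    @property
    def optimizer(self):
        # Evaluated on every access, so it always sees the current learning_rate.
        return make_optimizer(self.learning_rate)


hp = BasicHParams()
setattr(hp, "learning_rate", 1e-2)  # e.g. an override such as learning_rate=0.01
print(hp.optimizer)  # {'name': 'adam', 'learning_rate': 0.01}
```

One consequence of this design is that every access creates a fresh optimizer object, so it is best read once, as the training function does when it calls `compile`.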
Training loop
To train the models registered above, we need to write our own training function. Zookeeper then ties everything together:
```python
@cli.command()
@build_train()
def train(build_model, dataset, hparams, output_dir):
    """Start model training."""
    model = build_model(hparams, dataset)
    model.compile(
        optimizer=hparams.optimizer,
        loss="categorical_crossentropy",
        metrics=["categorical_accuracy", "top_k_categorical_accuracy"],
    )
    model.fit(
        dataset.train_data(hparams.batch_size),
        steps_per_epoch=dataset.train_examples // hparams.batch_size,
        validation_data=dataset.validation_data(hparams.batch_size),
        validation_steps=dataset.validation_examples // hparams.batch_size,
    )
```
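`steps_per_epoch` and `validation_steps` use integer division, so the number of examples processed per epoch is rounded down to a multiple of the batch size. A quick check of the arithmetic, assuming 60,000 training examples (the size of MNIST's standard training split; whether `dataset.train_examples` equals this depends on how Zookeeper splits the data). `steps_per_epoch` here is a hypothetical helper:

```python
def steps_per_epoch(num_examples, batch_size):
    # Integer division: a final partial batch does not add an extra step.
    return num_examples // batch_size


train_examples = 60_000  # assumed size of the training split

for batch_size in (32, 64):
    steps = steps_per_epoch(train_examples, batch_size)
    unused = train_examples - steps * batch_size
    print(batch_size, steps, unused)
# 32 1875 0
# 64 937 32
```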
This registers a Click command called `train`, which can be executed from the command line.
Command Line Interface
To make the file we just created executable, add the following lines at the bottom:

```python
if __name__ == "__main__":
    cli()
```
If you want to register your models in separate files, make sure to import them before calling `cli` so that Zookeeper can register them. To install your CLI as an executable command, check out the setuptools integration of Click.
Usage
Zookeeper already ships with `prepare`, `plot` and `tensorboard` commands, and now also includes the `train` command we created above:
```
python examples/train.py --help
Usage: train.py [OPTIONS] COMMAND [ARGS]...

Options:
  --help  Show this message and exit.

Commands:
  install-completion  Install shell completion.
  plot                Plot data examples.
  prepare             Downloads and prepares datasets for reading.
  tensorboard         Start TensorBoard to monitor model training.
  train               Start model training.
```
To train the model we just registered, run:

```shell
python examples/train.py train cnn --dataset mnist --hparams-set basic --hparams batch_size=64
```
Multiple hyperparameter overrides are separated by commas, and strings should be passed without quotation marks:

```shell
python examples/train.py train cnn --dataset mnist --hparams-set basic --hparams batch_size=32,activation=relu
```
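To illustrate the `--hparams` syntax, here is a minimal parser for strings such as `batch_size=32,activation=relu`. It is a hypothetical sketch, not Zookeeper's own parser: values that look like integers or floats are converted, everything else is kept as a plain string, and values that themselves contain commas (such as a list for `filters`) are not supported.

```python
def parse_overrides(spec):
    """Parse 'key=value,key=value' into a dict (hypothetical helper)."""
    overrides = {}
    for item in filter(None, spec.split(",")):
        key, _, raw = item.partition("=")
        raw = raw.strip()
        for convert in (int, float):
            try:
                value = convert(raw)
                break
            except ValueError:
                continue
        else:
            value = raw  # not numeric: keep as a string, no quotes needed
        overrides[key.strip()] = value
    return overrides


print(parse_overrides("batch_size=32,activation=relu"))
# {'batch_size': 32, 'activation': 'relu'}
```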