
A small library for managing deep learning models, hyperparameters and datasets


Zookeeper


A small library for managing deep learning models, hyperparameters and datasets, designed to make training deep learning models easy and reproducible.

Getting Started

Zookeeper allows you to build command line interfaces for training deep learning models with very little boilerplate using Click and TensorFlow Datasets. It helps you structure your machine learning projects in a framework-agnostic and effective way. Zookeeper is heavily inspired by Tensor2Tensor and Fairseq, but is designed to be used as a library, making it lightweight and very flexible.

Installation

pip install zookeeper
pip install colorama  # optional for colored console output

Registry

Zookeeper keeps track of data preprocessing, models, and hyperparameters so that you can reference them by name from the command line.

Datasets and Preprocessing

TensorFlow Datasets provides many popular datasets that can be downloaded automatically. In the following we will use MNIST and define a default preprocessing step that scales the images to [0, 1] and one-hot encodes the class labels:

import tensorflow as tf
from zookeeper import cli, build_train, HParams, registry, Preprocessing

class ImageClassification(Preprocessing):
    @property
    def kwargs(self):
        return {
            "input_shape": self.features["image"].shape,
            "num_classes": self.features["label"].num_classes,
        }

    def inputs(self, data):
        return tf.cast(data["image"], tf.float32)

    def outputs(self, data):
        return tf.one_hot(data["label"], self.features["label"].num_classes)


@registry.register_preprocess("mnist")
class default(ImageClassification):
    def inputs(self, data):
        return super().inputs(data) / 255
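
The same base class can be reused to register preprocessing for other image datasets from TensorFlow Datasets. Below is a minimal sketch, assuming the dataset (here cifar10) is available and exposes the same image and label features; in a real project this class would live in its own module so its name does not shadow the MNIST class above:

@registry.register_preprocess("cifar10")
class default(ImageClassification):
    def inputs(self, data):
        # Scale uint8 pixel values to [0, 1], just like for MNIST above.
        return super().inputs(data) / 255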

Models

Next we will register a model called cnn. We will use the Keras API for this:

@registry.register_model
def cnn(hp, input_shape, num_classes):
    return tf.keras.models.Sequential(
        [
            tf.keras.layers.Conv2D(
                hp.filters[0], (3, 3), activation=hp.activation, input_shape=input_shape
            ),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(hp.filters[1], (3, 3), activation=hp.activation),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(hp.filters[2], (3, 3), activation=hp.activation),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(hp.filters[3], activation=hp.activation),
            tf.keras.layers.Dense(num_classes, activation="softmax"),
        ]
    )
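
Multiple models can live in the registry side by side and are selected by name on the command line. As an illustrative sketch only (the model name mlp and the hidden_units hyperparameter are not part of the example project and would need a matching hyperparameter set), a fully connected baseline could be registered like this:

@registry.register_model
def mlp(hp, input_shape, num_classes):
    # Hypothetical dense baseline; hp.hidden_units must be defined in the
    # hyperparameter set registered for this model.
    return tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=input_shape),
            tf.keras.layers.Dense(hp.hidden_units, activation=hp.activation),
            tf.keras.layers.Dense(num_classes, activation="softmax"),
        ]
    )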

Hyperparameters

For each model we can register one or more sets of hyperparameters that will be passed to the model function when it is called:

@registry.register_hparams(cnn)
class basic(HParams):
    activation = "relu"
    batch_size = 32
    filters = [64, 64, 64, 64]
    learning_rate = 1e-3

    @property
    def optimizer(self):
        return tf.keras.optimizers.Adam(self.learning_rate)
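
Since more than one set can be registered per model, variations are easy to add. As an illustrative sketch (the values below are arbitrary), a second set for the cnn model could look like this and would be selected with --hparams-set large:

@registry.register_hparams(cnn)
class large(HParams):
    activation = "relu"
    batch_size = 64
    filters = [128, 128, 128, 256]
    learning_rate = 5e-4

    @property
    def optimizer(self):
        return tf.keras.optimizers.Adam(self.learning_rate)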

Training loop

To train the models registered above we will need to write a custom training loop. Zookeeper will then tie everything together:

@cli.command()
@build_train()
def train(build_model, dataset, hparams, output_dir):
    """Start model training."""
    model = build_model(hparams, **dataset.preprocessing.kwargs)
    model.compile(
        optimizer=hparams.optimizer,
        loss="categorical_crossentropy",
        metrics=["categorical_accuracy", "top_k_categorical_accuracy"],
    )

    model.fit(
        dataset.train_data(hparams.batch_size),
        steps_per_epoch=dataset.train_examples // hparams.batch_size,
        validation_data=dataset.validation_data(hparams.batch_size),
        validation_steps=dataset.validation_examples // hparams.batch_size,
    )

This will register a Click command called train, which can be executed from the command line.
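
The output_dir argument provided by build_train can be used to store artifacts for the run. The loop above ignores it, but an optional extension is to save the trained model there; a minimal sketch using plain Keras (requires import os at the top of the file, and h5py for the HDF5 format) would add the following line at the end of train:

    # Hypothetical extension: save the trained model into the provided output_dir.
    model.save(os.path.join(output_dir, "model.h5"))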

Command Line Interface

To make the file we just created executable, we will add the following lines at the bottom:

if __name__ == "__main__":
    cli()

If you want to register your models in separate files, make sure to import them before calling cli so that Zookeeper can properly register them. To install your CLI as an executable command, check out the setuptools integration of Click.
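
As a rough sketch of that setuptools integration (the project, module, and command names below are placeholders, assuming the training script is a module named train), an entry point in setup.py could look like:

from setuptools import setup

setup(
    name="my-project",  # placeholder project name
    py_modules=["train"],  # the module containing the cli group above
    install_requires=["zookeeper"],
    entry_points={
        "console_scripts": [
            # Exposes the Click group cli from train.py as a my-train command.
            "my-train=train:cli",
        ]
    },
)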

Usage

Zookeeper already ships with prepare, plot, and tensorboard commands, but now also includes the train command we created above:

python examples/train.py --help
Usage: train.py [OPTIONS] COMMAND [ARGS]...

Options:
  --help  Show this message and exit.

Commands:
  install-completion  Install shell completion.
  plot                Plot data examples.
  prepare             Downloads and prepares datasets for reading.
  tensorboard         Start TensorBoard to monitor model training.
  train               Start model training.

To train the model we just registered run:

python examples/train.py train cnn --dataset mnist --hparams-set basic --hparams batch_size=64

Multiple arguments are separated by a comma, and strings should be passed without quotation marks:

python examples/train.py train cnn --dataset mnist --hparams-set basic --hparams batch_size=32,activation=relu

