
A flexible trainer interface for JAX and Haiku.


Bax

Bax, short for "boilerplate jax", is a small library that provides a flexible trainer interface for JAX.

Bax is rather strongly opinionated in a few ways. First, it is designed for use with the Haiku neural network library and is not compatible with, e.g., Flax. Second, Bax assumes that data will be provided as a tf.data.Dataset. The goal of this library is not to be a widely compatible, high-level framework (like Elegy).
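
If your data starts out as NumPy arrays, wrapping it in a tf.data.Dataset is straightforward. A minimal sketch using the standard tf.data API (the arrays here are placeholders standing in for real data):

import numpy as np
import tensorflow as tf

# Placeholder arrays standing in for your real data.
images = np.zeros((60000, 28, 28, 1), dtype=np.float32)
labels = np.zeros((60000,), dtype=np.int64)

# The dict keys carry through as the batch keys your loss function sees
# (as in the MNIST example below).
ds = (
    tf.data.Dataset.from_tensor_slices({"image": images, "label": labels})
    .shuffle(10_000)
    .batch(32, drop_remainder=True)
)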

If you are okay with making the above assumptions, then Bax will hopefully make your life much easier by implementing the boilerplate code involved in neural network training loops.
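
For context, here is roughly the boilerplate being abstracted away: a hand-written Haiku/Optax training step. This is an illustrative sketch, not Bax's actual internals:

import haiku as hk
import jax
import jax.numpy as jnp
import optax

def forward(x):
    model = hk.Sequential([hk.Flatten(), hk.nets.MLP([128, 128, 10])])
    return model(x)

# Haiku modules must be transformed into a pure (init, apply) pair.
model = hk.transform(forward)
optimizer = optax.adam(0.001)

rng = jax.random.PRNGKey(0)
params = model.init(rng, jnp.zeros((1, 28, 28, 1)))
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, rng, images, labels):
    def loss_fn(p):
        logits = model.apply(p, rng, images)
        one_hot = jax.nn.one_hot(labels, 10)
        return jnp.mean(optax.softmax_cross_entropy(logits, one_hot))

    # Compute the loss and its gradients, then apply the optimizer update.
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss

Bax's Trainer is meant to absorb this wiring (transform, optimizer state, compilation, and the step loop) so that you only write the loss function.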

Please note that this library has not yet been extensively tested.

Installation

You can install Bax via pip:

pip install git+https://github.com/rystrauss/bax

Usage

Below are some simple examples that illustrate how to use Bax.

MNIST Classification

import optax
import tensorflow_datasets as tfds
import haiku as hk
import jax.numpy as jnp
import jax

from bax.trainer import Trainer


# Use TensorFlow Datasets to get our MNIST data.
ds = tfds.load("mnist", split="train").batch(32, drop_remainder=True)

# The loss function that we want to minimize. The first argument is supplied
# by the Trainer and is unused in this example.
def loss_fn(_, batch):
    model = hk.Sequential([hk.Flatten(), hk.nets.MLP([128, 128, 10])])

    logits = model(batch["image"] / 255.0)
    labels = jax.nn.one_hot(batch["label"], 10)

    loss = jnp.mean(optax.softmax_cross_entropy(logits, labels))
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == batch["label"])

    # The first returned value is the loss, which is what will be minimized by
    # the trainer. The second value is a dictionary that can contain any other
    # metrics you are interested in (or it can be empty).
    return loss, {"accuracy": accuracy}

trainer = Trainer(loss=loss_fn, optimizer=optax.adam(0.001))

# You should see the loss and accuracy displayed during training.
trainer.fit(ds, 10000)
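
Since the optimizer is just an Optax gradient transformation, presumably any Optax combinator chain drops in without touching the loss function. A sketch using standard Optax combinators (the schedule and hyperparameters are arbitrary choices, not library defaults):

# Gradient clipping chained with AdamW on a cosine-decayed learning rate.
schedule = optax.cosine_decay_schedule(init_value=0.001, decay_steps=10000)
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=0.0001),
)

trainer = Trainer(loss=loss_fn, optimizer=optimizer)
trainer.fit(ds, 10000)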
