A flexible trainer interface for JAX and Haiku.

Bax

Bax, short for "boilerplate JAX", is a small library that provides a flexible trainer interface for JAX.

Bax is strongly opinionated in a few ways. First, it is designed for use with the Haiku neural network library and is not compatible with alternatives such as Flax. Second, it assumes that data will be provided as a tf.data.Dataset. The goal of this library is not to be widely compatible and high-level (like Elegy).

If you are okay with the above assumptions, Bax will hopefully make your life much easier by implementing the boilerplate code involved in neural network training loops.
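In practice, the tf.data.Dataset requirement is easy to satisfy even for in-memory data. The sketch below wraps NumPy arrays into a batched dataset of {"image", "label"} dictionaries, matching the structure that tfds.load produces for MNIST; the array names, shapes, and contents here are illustrative and not part of Bax's API.

```python
import numpy as np
import tensorflow as tf

# Illustrative in-memory data: 100 fake 28x28 grayscale images with labels.
images = np.random.randint(0, 256, size=(100, 28, 28, 1)).astype(np.uint8)
labels = np.random.randint(0, 10, size=(100,)).astype(np.int64)

# Wrap the arrays as a tf.data.Dataset of {"image", "label"} dictionaries,
# then shuffle and batch, as you would before passing it to the trainer.
ds = (
    tf.data.Dataset.from_tensor_slices({"image": images, "label": labels})
    .shuffle(buffer_size=100)
    .batch(32, drop_remainder=True)
)

# Each element yielded by the dataset is a dictionary of batched tensors.
batch = next(iter(ds.as_numpy_iterator()))
```
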

Please note that this library has not yet been extensively tested.

Installation

You can install Bax via pip:

pip install git+https://github.com/rystrauss/bax

Usage

Below is a simple example that illustrates how to use Bax.

MNIST Classification

import optax
import tensorflow_datasets as tfds
import haiku as hk
import jax.numpy as jnp
import jax

from bax.trainer import Trainer


# Use TensorFlow Datasets to get our MNIST data.
ds = tfds.load("mnist", split="train").batch(32, drop_remainder=True)

# The loss function that we want to minimize.
def loss_fn(_, batch):
    model = hk.Sequential([hk.Flatten(), hk.nets.MLP([128, 128, 10])])

    preds = model(batch["image"] / 255.0)
    labels = jax.nn.one_hot(batch["label"], 10)

    loss = jnp.mean(optax.softmax_cross_entropy(preds, labels))
    accuracy = jnp.mean(jnp.argmax(preds, axis=-1) == batch["label"])

    # The first returned value is the loss, which is what will be minimized by the
    # trainer. The second value is a dictionary that can contain other metrics you
    # might be interested in (or, it can just be empty).
    return loss, {"accuracy": accuracy}

trainer = Trainer(loss=loss_fn, optimizer=optax.adam(0.001))

# You should see the loss and accuracy be displayed during training.
trainer.fit(ds, 10000)
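The loss function's return convention is worth internalizing: a scalar loss that the trainer minimizes, plus a dictionary of auxiliary metrics. The computation itself is framework-agnostic; the following NumPy sketch mirrors what optax.softmax_cross_entropy and the accuracy metric compute above (the logits and labels are made-up values, not MNIST data).

```python
import numpy as np

def softmax_cross_entropy(logits, labels_onehot):
    # Numerically stable log-softmax, then the cross-entropy per example.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -(labels_onehot * log_probs).sum(axis=-1)

# Two examples with three classes; true classes are 0 and 2.
logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = np.array([0, 2])
onehot = np.eye(3)[labels]

# The scalar loss that the trainer would minimize...
loss = softmax_cross_entropy(logits, onehot).mean()
# ...and an auxiliary metric reported alongside it.
accuracy = (logits.argmax(axis=-1) == labels).mean()
```

Both predictions are correct here, so the accuracy is 1.0, while the loss is small but nonzero because the softmax never assigns probability exactly 1.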

