A flexible trainer interface for JAX and Haiku.
Project description
Bax
Bax, short for "boilerplate jax", is a small library that provides a flexible trainer interface for JAX.
Bax is rather strongly opinionated in a few ways. First, it is designed for use with the Haiku neural network library and is not compatible with, e.g., Flax. Second, Bax assumes that data will be provided as a tf.data.Dataset. The goal of this library is not to be widely compatible and high-level (like Elegy).
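For example, if your data lives in in-memory NumPy arrays, wrapping it in a tf.data.Dataset is straightforward. The following is a minimal sketch (not from Bax's documentation), with placeholder arrays standing in for real data:

import numpy as np
import tensorflow as tf

# Placeholder arrays standing in for a real dataset.
images = np.zeros((60000, 28, 28, 1), dtype=np.uint8)
labels = np.zeros((60000,), dtype=np.int64)

# Bax only requires that the data end up in a tf.data.Dataset.
train_ds = (
    tf.data.Dataset.from_tensor_slices({"image": images, "label": labels})
    .shuffle(10_000)
    .batch(32, drop_remainder=True)
)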
If you are okay with making the above assumptions, then Bax will hopefully make your life much easier by implementing the boilerplate code involved in neural network training loops.
Please note that this library has not yet been extensively tested.
Installation
You can install Bax via pip:
pip install bax
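Note that JAX itself is usually installed separately so that you can choose the build matching your accelerator (CPU, CUDA, or TPU). As a sketch of a typical CPU-only setup (an assumption, not taken from Bax's documentation):

# Install a CPU-only JAX build first, then Bax.
pip install --upgrade "jax[cpu]"
pip install bax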
Usage
Below is a simple example that illustrates how to use Bax.
MNIST Classification
import optax
import tensorflow_datasets as tfds
import haiku as hk
import jax.numpy as jnp
import jax
from bax.trainer import Trainer
# Use TensorFlow Datasets to get our MNIST data.
train_ds = tfds.load("mnist", split="train").batch(32, drop_remainder=True)
test_ds = tfds.load("mnist", split="test").batch(32, drop_remainder=True)
# The loss function that we want to minimize.
def loss_fn(step, is_training, batch):
    model = hk.Sequential([hk.Flatten(), hk.nets.MLP([128, 128, 10])])
    preds = model(batch["image"] / 255.0)
    labels = jax.nn.one_hot(batch["label"], 10)

    loss = jnp.mean(optax.softmax_cross_entropy(preds, labels))
    accuracy = jnp.mean(jnp.argmax(preds, axis=-1) == batch["label"])

    # The first returned value is the loss, which is what will be minimized by the
    # trainer. The second value is a dictionary that can contain other metrics you
    # might be interested in (or, it can just be empty).
    return loss, {"accuracy": accuracy}
trainer = Trainer(loss=loss_fn, optimizer=optax.adam(0.001))
# Run the training loop. Metrics will be printed out each time the validation
# dataset is evaluated (in this case, every 1000 steps).
trainer.fit(train_ds, steps=10000, val_dataset=test_ds, validation_freq=1000)
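Because the network is defined with Haiku inside the loss function, the same architecture can be reused outside the trainer by wrapping a forward function in hk.transform. The sketch below is illustrative and not part of Bax's API; whether parameters produced by training can be plugged in directly depends on how the Haiku module names line up.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(images):
    # Same architecture as in loss_fn above.
    model = hk.Sequential([hk.Flatten(), hk.nets.MLP([128, 128, 10])])
    return model(images / 255.0)

predict = hk.without_apply_rng(hk.transform(forward))

# Randomly initialized here for illustration; in practice you would use the
# parameters produced by training.
params = predict.init(jax.random.PRNGKey(0), jnp.zeros((1, 28, 28, 1)))
logits = predict.apply(params, jnp.zeros((1, 28, 28, 1)))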
Project details
Download files
Source Distribution
bax-0.1.12.tar.gz (10.3 kB)
Built Distribution
bax-0.1.12-py3-none-any.whl (10.1 kB)
File details
Details for the file bax-0.1.12.tar.gz.
File metadata
- Download URL: bax-0.1.12.tar.gz
- Upload date:
- Size: 10.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.9.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | d28e4d7ef8cc86773f08180f5f165fa25f12f9c7885cf5f28a94ade4932b25a6
MD5 | 8ef2408aa774cfb930c59ddfc6338f04
BLAKE2b-256 | 75a1ad573ed74542a0248e7eb8661c19ce1dab6bc882a60ef4960c9b3f9e7b8f
File details
Details for the file bax-0.1.12-py3-none-any.whl.
File metadata
- Download URL: bax-0.1.12-py3-none-any.whl
- Upload date:
- Size: 10.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.9.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 54742c6c1cbcc64b3bb6051b87aa79b04efbb90b332e2c2f83f6fe0d14bfdbd6
MD5 | 2a3f4a9c81861824ff6a9ea08ba98f0c
BLAKE2b-256 | 827461457c4a693c713d73826c6aed27d07b16c97a9468c19cf3521c4611ae82