A high-level API for building and training Flax NNX models.
Project description
blaxbird [blækbɜːd]
A high-level API to build and train NNX models
About
Blaxbird [blækbɜːd] is a high-level API to easily build NNX models and train them on CPU or GPU.
Using blaxbird, one can
- concisely define models and loss functions without the usual JAX/Flax verbosity,
- easily define checkpointers that save the best and most current network weights,
- distribute data and model weights over multiple processes or GPUs,
- define hooks that are periodically called during training.
In addition, blaxbird offers high-quality implementations of common neural network modules and algorithms, such as:
- MLP, Diffusion Transformer,
- Flow Matching and Denoising Score Matching (EDM schedules) with Euler and Heun samplers,
- Consistency Distillation/Matching.
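The Euler and Heun samplers listed above differ in how they step along an ODE dx/dt = v(x, t), such as the one integrated when sampling from a flow-matching model. The following plain-Python sketch compares a first-order Euler step with a second-order Heun step. The function names and the toy velocity field (used in place of a trained network) are hypothetical, and this is not blaxbird's sampler API:

```python
import math
from typing import Callable

# v(x, t): velocity field; here a toy function instead of a trained network (assumption)
Velocity = Callable[[float, float], float]


def euler_step(v: Velocity, x: float, t: float, h: float) -> float:
    # First-order update: follow the velocity at the current point.
    return x + h * v(x, t)


def heun_step(v: Velocity, x: float, t: float, h: float) -> float:
    # Second-order update: average the velocity at the current point and at
    # the Euler-predicted next point (one extra velocity evaluation per step).
    d1 = v(x, t)
    x_pred = x + h * d1
    d2 = v(x_pred, t + h)
    return x + 0.5 * h * (d1 + d2)


def integrate(step_fn, v: Velocity, x0: float, n_steps: int,
              t0: float = 0.0, t1: float = 1.0) -> float:
    # Integrate from t0 to t1 with n_steps equally sized steps.
    h = (t1 - t0) / n_steps
    x, t = x0, t0
    for _ in range(n_steps):
        x = step_fn(v, x, t, h)
        t += h
    return x


# Toy field v(x, t) = -x; with x(0) = 1 the exact solution is x(1) = exp(-1).
v = lambda x, t: -x
exact = math.exp(-1.0)
err_euler = abs(integrate(euler_step, v, 1.0, 10) - exact)
err_heun = abs(integrate(heun_step, v, 1.0, 10) - exact)
print(f"Euler error: {err_euler:.4f}, Heun error: {err_heun:.4f}")
```

On this toy problem Heun's method is far more accurate for the same number of steps, at the cost of two velocity evaluations per step instead of one.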
Example
To use blaxbird, one only needs to define a model, a loss function, and train and validation step functions:
```python
import optax
from flax import nnx

class CNN(nnx.Module):
    ...

def loss_fn(model, images, labels):
    logits = model(images)
    return optax.losses.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()

def train_step(model, rng_key, batch):
    return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])

def val_step(model, rng_key, batch):
    return loss_fn(model, batch["image"], batch["label"])
```
You can then construct (and use) a training function like this:
```python
import optax
from flax import nnx
from jax import random as jr
from blaxbird import train_fn

model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
optimizer = nnx.Optimizer(model, optax.adam(1e-4))

train = train_fn(
    fns=(train_step, val_step),
    n_steps=100,
    eval_every_n_steps=10,
    n_eval_batches=10,
)
# train_itr and val_itr are the training and validation data iterators
train(jr.key(2), model, optimizer, train_itr, val_itr)
```
See the entire self-contained example in examples/mnist_classification.
Usage
train_fn is a higher-order function with the following signature:
```python
def train_fn(
    *,
    fns: tuple[Callable, Callable],
    shardings: Optional[tuple[jax.NamedSharding, jax.NamedSharding]] = None,
    n_steps: int,
    eval_every_n_steps: int,
    n_eval_batches: int,
    log_to_wandb: bool = False,
    hooks: Iterable[Callable] = (),
) -> Callable:
    ...
```
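How n_steps, eval_every_n_steps and n_eval_batches interact can be illustrated with a small pure-Python loop. This is a sketch under assumptions (a 1-based step counter, evaluation after every eval_every_n_steps-th step on n_eval_batches batches drawn from the validation iterator, and hooks called after every step, deciding themselves when to act); it is not blaxbird's actual implementation, and `sketch_train_loop` is a hypothetical name:

```python
import itertools
from typing import Callable, Iterable, Iterator


def sketch_train_loop(
    train_step: Callable,
    val_step: Callable,
    train_itr: Iterator,
    val_itr: Iterator,
    *,
    n_steps: int,
    eval_every_n_steps: int,
    n_eval_batches: int,
    hooks: Iterable[Callable] = (),
) -> list[tuple[int, float]]:
    """Hypothetical control flow only: no model, optimizer, or sharding."""
    hooks = list(hooks)  # materialize so generators are not consumed on step 1
    history = []
    for step in range(1, n_steps + 1):
        train_step(next(train_itr))  # one optimization step on one batch
        if step % eval_every_n_steps == 0:
            # Average the validation loss over n_eval_batches batches.
            losses = [val_step(next(val_itr)) for _ in range(n_eval_batches)]
            history.append((step, sum(losses) / n_eval_batches))
        for hook in hooks:
            hook(step, model=None)  # the real loop would pass the model here
    return history


# Usage with dummy iterators and the settings from the Example section.
train_itr = itertools.count()
val_itr = itertools.count()
history = sketch_train_loop(
    train_step=lambda batch: 0.0,
    val_step=lambda batch: float(batch),
    train_itr=train_itr,
    val_itr=val_itr,
    n_steps=100,
    eval_every_n_steps=10,
    n_eval_batches=10,
)
```

Under these assumptions, n_steps=100 with eval_every_n_steps=10 and n_eval_batches=10 yields 10 evaluations that together consume 100 validation batches.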
We briefly explain the more ambiguous argument types below.
fns
fns is a required argument consisting of a tuple of two functions: a training step function and a validation step function.
In the simplest case they look like this:
```python
def train_step(model, rng_key, batch):
    return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])

def val_step(model, rng_key, batch):
    return loss_fn(model, batch["image"], batch["label"])
```
Both train_step and val_step have the same arguments and argument types:
- model specifies an nnx.Module, i.e., a neural network like the CNN shown above.
- rng_key is a jax.random.key in case you need to generate random numbers.
- batch is a sample from a data loader (to be specified later).
The loss function called by both computes a scalar loss value.
While train_step has to return both the loss and the gradients, val_step only needs
to return the loss.
shardings
To specify how data and model weights are distributed over devices and processes,
blaxbird uses JAX's sharding functionality.
shardings is again a tuple: one sharding for the model and one for the data.
The example below only distributes the data over num_devices devices.
If you don't want to distribute anything, set the argument to None or omit it.
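```python
import jax
from jax.experimental import mesh_utils

def get_sharding():
    num_devices = jax.local_device_count()
    # 1D device mesh with a single axis named "data"
    mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh((num_devices,)), ("data",)
    )
    # empty PartitionSpec: model weights are replicated on all devices
    model_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec())
    # the leading (batch) axis of the data is split over the "data" mesh axis
    data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))
    return model_sharding, data_sharding
```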
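To see what the two shardings do, you can place arrays on the mesh with jax.device_put and inspect the resulting shards. The sketch below reuses get_sharding from above (copied so the block runs on its own) and a hypothetical batch of 8 MNIST-sized images per device; how blaxbird applies the shardings internally is not shown. On a single CPU the mesh simply has one device:

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils


def get_sharding():
    num_devices = jax.local_device_count()
    mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh((num_devices,)), ("data",)
    )
    # Empty PartitionSpec: replicated on every device.
    model_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    # Leading (batch) axis split across the "data" mesh axis.
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))
    return model_sharding, data_sharding


model_sharding, data_sharding = get_sharding()
num_devices = jax.local_device_count()

# Hypothetical batch: the batch size must be divisible by the number of devices.
images = jnp.zeros((8 * num_devices, 28, 28, 1))
sharded_images = jax.device_put(images, data_sharding)

# Hypothetical weight matrix: every device holds a full copy.
weights = jax.device_put(jnp.ones((3, 3)), model_sharding)

for shard in sharded_images.addressable_shards:
    print(shard.device, shard.data.shape)  # each device holds 8 images
```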
hooks
hooks is an iterable (for instance, a list) of callables that are periodically called during training.
Each hook has to have the following signature:
```python
def hook_fn(step, *, model, **kwargs) -> None:
    ...
```
It takes an integer step specifying the current training iteration and, as a keyword argument, the model itself.
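A minimal hook following this signature only needs to accept the step positionally and the model as a keyword argument; accepting **kwargs keeps it compatible with any additional keyword arguments the training loop may pass. The hypothetical `make_step_logger` below records the steps at which it fires, and calling it by hand mimics what the training loop would do:

```python
from typing import Callable


def make_step_logger(every_n_steps: int) -> tuple[Callable, list[int]]:
    """Hypothetical hook factory: records every `every_n_steps`-th step."""
    fired_at: list[int] = []

    def hook(step: int, *, model, **kwargs) -> None:
        # Do nothing on steps that are not a multiple of every_n_steps.
        if step % every_n_steps != 0:
            return
        fired_at.append(step)

    return hook, fired_at


hook, fired_at = make_step_logger(every_n_steps=25)
for step in range(1, 101):
    hook(step, model=None)  # the training loop would pass the actual model
```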
For instance, if you want to track custom metrics during validation, you could create a hook like this:
```python
import logging

import jax
import optax
from flax import nnx

def hook_fn(metrics, val_iter, hook_every_n_steps):
    def fn(step, *, model, **kwargs):
        # only run every hook_every_n_steps steps
        if step % hook_every_n_steps != 0:
            return
        for batch in val_iter:
            logits = model(batch["image"])
            loss = optax.softmax_cross_entropy_with_integer_labels(
                logits=logits, labels=batch["label"]
            ).mean()
            metrics.update(loss=loss, logits=logits, labels=batch["label"])
        # log only on the first process to avoid duplicate output
        if jax.process_index() == 0:
            curr_metrics = ", ".join(
                [f"{k}: {v}" for k, v in metrics.compute().items()]
            )
            logging.info(f"metrics at step {step}: {curr_metrics}")
        metrics.reset()
    return fn

metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average("loss"),
)
hook = hook_fn(metrics, val_iter, hook_every_n_steps)
```
This creates a hook function hook that, every hook_every_n_steps steps, iterates over the validation set,
computes accuracy and loss, and then logs the results.
To provide multiple hooks to the train function, collect them in a list.
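One caveat with the metrics hook: if val_iter is a one-shot Python iterator (for instance a generator), the `for batch in val_iter` loop exhausts it on the first call, and later calls silently see no batches. The pure-Python sketch below uses a hypothetical `make_val_iter` factory and a counting hook, not blaxbird code, to show the problem and one workaround: pass a function that builds a fresh iterator on each call.

```python
from typing import Callable, Iterator


def make_val_iter() -> Iterator[int]:
    # Hypothetical validation "loader" yielding three batches.
    return iter([0, 1, 2])


def make_counting_hook(get_batches: Callable[[], Iterator[int]], counts: list[int]):
    def hook(step, *, model, **kwargs):
        # Record how many batches are visible at this call.
        counts.append(sum(1 for _ in get_batches()))
    return hook


shared_iter = make_val_iter()
counts_shared: list[int] = []
counts_fresh: list[int] = []
hooks = [
    make_counting_hook(lambda: shared_iter, counts_shared),  # reuses one iterator
    make_counting_hook(make_val_iter, counts_fresh),         # fresh iterator per call
]
for step in (10, 20):
    for hook in hooks:  # this sketch calls each hook in turn
        hook(step, model=None)
```

The shared iterator yields all three batches only on the first call, whereas the factory-based hook sees every batch each time.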
A checkpointing hook
We provide a convenient checkpointing hook, which can be constructed using
get_default_checkpointer. The checkpointer saves both the k checkpoints with the lowest
validation loss and the most recent training checkpoint.
The signature of get_default_checkpointer is:
```python
def get_default_checkpointer(
    outfolder: str,
    *,
    save_every_n_steps: int,
    max_to_keep: int = 5,
) -> tuple[Callable, Callable, Callable]:
    ...
```
Its arguments are:
- outfolder: the folder in which to store the checkpoints.
- save_every_n_steps: after how many training steps a checkpoint is stored.
- max_to_keep: the number of checkpoints to keep before old checkpoints are removed (so as not to fill up the storage).
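The retention policy described above (the k checkpoints with the lowest validation loss, plus the most recent one) can be illustrated with a small bookkeeping class. This is a pure-Python sketch, not blaxbird's implementation; `BestAndLastTracker` is a hypothetical name, and treating k as the max_to_keep argument is an assumption:

```python
from typing import Optional


class BestAndLastTracker:
    """Hypothetical bookkeeping: keep the k lowest-loss steps and the latest step."""

    def __init__(self, k: int):
        self.k = k  # assumed to play the role of max_to_keep
        self.best: list[tuple[float, int]] = []  # (val_loss, step), ascending loss
        self.last_step: Optional[int] = None

    def update(self, step: int, val_loss: float) -> None:
        self.last_step = step
        self.best.append((val_loss, step))
        self.best.sort()
        del self.best[self.k:]  # drop everything beyond the k best

    def kept_steps(self) -> list[int]:
        steps = {step for _, step in self.best}
        if self.last_step is not None:
            steps.add(self.last_step)
        return sorted(steps)


tracker = BestAndLastTracker(k=2)
for step, val_loss in [(100, 0.5), (200, 0.3), (300, 0.4), (400, 0.2), (500, 0.6)]:
    tracker.update(step, val_loss)
print(tracker.kept_steps())  # [200, 400, 500]
```

Step 500 is kept only because it is the most recent one, even though its validation loss is the worst.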
For instance, you would construct the checkpointing hook like this:
```python
from blaxbird import get_default_checkpointer

hook_save, *_ = get_default_checkpointer(
    "checkpoints", save_every_n_steps=100
)
```
Restoring a run
You can also use get_default_checkpointer to resume a run where you left off.
get_default_checkpointer in fact returns three functions: one for saving checkpoints and two for restoring
them:
```python
from blaxbird import get_default_checkpointer

save, restore_best, restore_last = get_default_checkpointer(
    "checkpoints", save_every_n_steps=100
)
```
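You can then restore either the best or the most recent checkpoint:
```python
model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
optimizer = nnx.Optimizer(model, optax.adam(1e-4))

# either restore the checkpoint with the lowest validation loss
model, optimizer = restore_best(model, optimizer)
# or restore the most recent training checkpoint
model, optimizer = restore_last(model, optimizer)
```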
Doing training
After defining the step functions, hooks, and shardings, you can train your model like this:
```python
train = train_fn(
    fns=(train_step, val_step),
    n_steps=n_steps,
    eval_every_n_steps=eval_every_n_steps,
    n_eval_batches=n_eval_batches,
    shardings=(model_sharding, data_sharding),
    hooks=hooks,
    log_to_wandb=False,
)
train(jr.key(1), model, optimizer, train_itr, val_itr)
```
Self-contained examples, which also show what the data loaders should look like, can be found in examples.
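The step functions in this README read batch["image"] and batch["label"], so a loader only needs to yield dictionaries with those keys. The generator below is a hypothetical stand-in that yields random MNIST-shaped NumPy batches indefinitely; whether blaxbird expects infinite iterators, and which array shapes your model needs, are assumptions to check against the examples:

```python
from typing import Iterator

import numpy as np


def random_batches(batch_size: int, seed: int = 0) -> Iterator[dict]:
    """Hypothetical infinite loader yielding {"image", "label"} batches."""
    rng = np.random.default_rng(seed)
    while True:
        yield {
            # MNIST-sized grayscale images in NHWC layout.
            "image": rng.random((batch_size, 28, 28, 1), dtype=np.float32),
            # Integer class labels in [0, 10).
            "label": rng.integers(0, 10, size=(batch_size,), dtype=np.int32),
        }


train_itr = random_batches(batch_size=32, seed=0)
val_itr = random_batches(batch_size=32, seed=1)
batch = next(train_itr)
```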
Installation
To install the package from PyPI, call:
```shell
pip install blaxbird
```
To install a release directly from GitHub, run the following on the command line:
```shell
pip install git+https://github.com/dirmeier/blaxbird@<RELEASE>
```
Author
Simon Dirmeier simd@mailbox.org
Download files
File details
Details for the file blaxbird-0.1.0.tar.gz.
File metadata
- Download URL: blaxbird-0.1.0.tar.gz
- Upload date:
- Size: 193.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f81a7eab81549d37bf5290749312002a964e3340f90b4579ff3ed7097cbc8758 |
| MD5 | 9d401127fd2dfbfdb4fbda6a118a18dc |
| BLAKE2b-256 | d4df1f4a847a43a6687c8b19bb49111a53c80eed00d5f815bc75874c702c2fdb |
File details
Details for the file blaxbird-0.1.0-py3-none-any.whl.
File metadata
- Download URL: blaxbird-0.1.0-py3-none-any.whl
- Upload date:
- Size: 24.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4e66fae4cb45e608ef619de03e59e8092062b23727996b46c65853971add82cd |
| MD5 | 5421c0daead642064c01358c07e39e1d |
| BLAKE2b-256 | 5e37fd1c9e8de74d3b798d221d7dfaff29668a3df52d1e8d1311fc6286ed16ae |