A high-level API for building and training Flax NNX models.
blaxbird [blækbɜːd]
About
A high-level API to build and train NNX models.
Define the module
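```python
import optax
from flax import nnx


class CNN(nnx.Module):
    ...  # layers and __call__ omitted


def loss_fn(model, images, labels):
    # Mean softmax cross-entropy between the model's logits and integer class labels
    logits = model(images)
    return optax.losses.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()


def train_step(model, rng_key, batch, **kwargs):
    # Returns the (loss, grads) pair produced by nnx.value_and_grad
    return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])


def val_step(model, rng_key, batch, **kwargs):
    # Returns only the loss; no gradients are needed for validation
    return loss_fn(model, batch["image"], batch["label"])
```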
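For reference, `optax.losses.softmax_cross_entropy_with_integer_labels` computes, for each example, the negative log-softmax probability of the true class, and `loss_fn` averages that over the batch. The pure-Python sketch below spells out the same computation on plain lists instead of JAX arrays; the helper name `softmax_cross_entropy_int_labels` is hypothetical and not part of blaxbird or Optax.

```python
import math
from typing import Sequence


def softmax_cross_entropy_int_labels(
    logits: Sequence[Sequence[float]], labels: Sequence[int]
) -> float:
    """Mean over the batch of -log softmax(logits)[label] (hypothetical helper)."""
    total = 0.0
    for row, label in zip(logits, labels):
        # log-sum-exp with the row maximum subtracted for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        # -log softmax(row)[label] == log_z - row[label]
        total += log_z - row[label]
    return total / len(labels)
```

With two equal logits the loss is `log(2)`, the cross-entropy of a uniform guess over two classes.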
Define the trainer
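```python
from jax import random as jr
from flax import nnx
from blaxbird import train_fn

# CNN, train_step and val_step are defined above; get_optimizer, n_steps,
# n_eval_frequency, n_eval_batches, train_itr and val_itr are user-supplied (not shown)
model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
optimizer = get_optimizer(model)

train = train_fn(
    fns=(train_step, val_step),
    n_steps=n_steps,
    n_eval_frequency=n_eval_frequency,
    n_eval_batches=n_eval_batches,
)
train(jr.key(2), model, optimizer, train_itr, val_itr)
```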
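The meaning of `n_steps`, `n_eval_frequency` and `n_eval_batches` is not spelled out here. One plausible reading, and it is only an assumption, is: run `n_steps` training steps, and every `n_eval_frequency` steps average the validation loss over `n_eval_batches` batches drawn from `val_itr`. The framework-free sketch below illustrates that reading with a hypothetical `run_training_loop`; it leaves out the model, optimizer and RNG keys and is not blaxbird's implementation.

```python
from typing import Any, Callable, Iterator


def run_training_loop(
    train_step_fn: Callable[[Any], Any],
    val_step_fn: Callable[[Any], float],
    train_itr: Iterator[Any],
    val_itr: Iterator[Any],
    n_steps: int,
    n_eval_frequency: int,
    n_eval_batches: int,
) -> list:
    """Hypothetical loop illustrating an assumed reading of train_fn's arguments."""
    history = []  # (step, mean validation loss) pairs
    for step in range(1, n_steps + 1):
        # one optimisation step on the next training batch
        train_step_fn(next(train_itr))
        if step % n_eval_frequency == 0:
            # evaluate on a fixed number of validation batches and average
            losses = [val_step_fn(next(val_itr)) for _ in range(n_eval_batches)]
            history.append((step, sum(losses) / n_eval_batches))
    return history
```

Under this reading, `n_steps=10`, `n_eval_frequency=5` and `n_eval_batches=2` would evaluate twice, after steps 5 and 10, consuming four validation batches in total.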
Author
Simon Dirmeier <simd@mailbox.org>
Download files
Source Distribution
blaxbird-0.0.1.tar.gz (211.4 kB)
Built Distribution
blaxbird-0.0.1-py3-none-any.whl (12.7 kB)
File details
Details for the file blaxbird-0.0.1.tar.gz.
File metadata
- Download URL: blaxbird-0.0.1.tar.gz
- Upload date:
- Size: 211.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 21b949d3991b54799748304f7d2772d7911079627f3997079d327d9efa41c377 |
| MD5 | a8d5c2d7b1eaeb9bfac3d3db44549cf5 |
| BLAKE2b-256 | c87ca2acdf91155c28180ba8dacda3e573908947131fb1b93a7951a0c7a7e3c6 |
File details
Details for the file blaxbird-0.0.1-py3-none-any.whl.
File metadata
- Download URL: blaxbird-0.0.1-py3-none-any.whl
- Upload date:
- Size: 12.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 807352aa0d86bc0631397b75855c373bc6a0995f5a0806e84fa993f071db5eae |
| MD5 | a38131979d262fd746ea6b7df46ea999 |
| BLAKE2b-256 | e1579014cddfcf0cb6b881d7eccbfb3918bb362d8361f575df0f9440402032ba |
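To check a downloaded file against the SHA256 digests published for blaxbird 0.0.1, a short `hashlib` sketch is enough. The helper names `sha256_hex` and `verify_download` are hypothetical, and the sketch assumes the file keeps its original name, since that name is used to look up the expected digest.

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the "File hashes" tables for blaxbird 0.0.1
EXPECTED_SHA256 = {
    "blaxbird-0.0.1.tar.gz": "21b949d3991b54799748304f7d2772d7911079627f3997079d327d9efa41c377",
    "blaxbird-0.0.1-py3-none-any.whl": "807352aa0d86bc0631397b75855c373bc6a0995f5a0806e84fa993f071db5eae",
}


def sha256_hex(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path):
    """True if the file's SHA256 matches the published digest for its name.

    Raises KeyError for file names not listed in EXPECTED_SHA256.
    """
    return sha256_hex(path) == EXPECTED_SHA256[Path(path).name]
```

For example, `verify_download("blaxbird-0.0.1.tar.gz")` run in the download directory returns `True` only if the archive is byte-for-byte the published one.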