
A high-level API for building and training Flax NNX models.

Project description

blaxbird [blækbɜːd]

About

A high-level API to build and train NNX models.

Define the module

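import optax
from flax import nnx

class CNN(nnx.Module):
    ...

def loss_fn(model, images, labels):
    # Mean softmax cross-entropy over the batch; labels are integer class indices.
    logits = model(images)
    return optax.losses.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()


def train_step(model, rng_key, batch, **kwargs):
    # Returns (loss, grads); grads holds the gradients w.r.t. the model's nnx.Param variables.
    return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])


def val_step(model, rng_key, batch, **kwargs):
    # Returns the scalar validation loss; no gradients are computed.
    return loss_fn(model, batch["image"], batch["label"])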

Define the trainer

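from jax import random as jr
from flax import nnx

from blaxbird import train_fn

# CNN, train_step and val_step come from the previous snippet. get_optimizer,
# n_steps, n_eval_frequency, n_eval_batches, train_itr and val_itr are
# user-supplied and not defined here.
model = CNN(rngs=nnx.Rngs(jr.key(1)))
optimizer = get_optimizer(model)

train = train_fn(
    fns=(train_step, val_step),
    n_steps=n_steps,
    n_eval_frequency=n_eval_frequency,
    n_eval_batches=n_eval_batches,
)
train(jr.key(2), model, optimizer, train_itr, val_itr)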

Author

Simon Dirmeier simd@mailbox.org


Download files

Source Distribution

blaxbird-0.0.1.tar.gz (211.4 kB)

Uploaded Source

Built Distribution

blaxbird-0.0.1-py3-none-any.whl (12.7 kB)

Uploaded Python 3

File details

Details for the file blaxbird-0.0.1.tar.gz.

File metadata

  • Download URL: blaxbird-0.0.1.tar.gz
  • Upload date:
  • Size: 211.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.9

File hashes

Hashes for blaxbird-0.0.1.tar.gz
Algorithm     Hash digest
SHA256        21b949d3991b54799748304f7d2772d7911079627f3997079d327d9efa41c377
MD5           a8d5c2d7b1eaeb9bfac3d3db44549cf5
BLAKE2b-256   c87ca2acdf91155c28180ba8dacda3e573908947131fb1b93a7951a0c7a7e3c6

File details

Details for the file blaxbird-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: blaxbird-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 12.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.9

File hashes

Hashes for blaxbird-0.0.1-py3-none-any.whl
Algorithm     Hash digest
SHA256        807352aa0d86bc0631397b75855c373bc6a0995f5a0806e84fa993f071db5eae
MD5           a38131979d262fd746ea6b7df46ea999
BLAKE2b-256   e1579014cddfcf0cb6b881d7eccbfb3918bb362d8361f575df0f9440402032ba
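The published digests can be used to check a downloaded file before installing it. The sketch below uses only Python's standard `hashlib`; the helper names `sha256_hex` and `sha256_matches` and the local file path are illustrative assumptions. pip can enforce the same check through a `--hash=sha256:...` entry in a requirements file.

```python
import hashlib
import io
from pathlib import Path
from typing import BinaryIO


def sha256_hex(stream: BinaryIO, chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 digest of a binary stream, read in chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    for chunk in iter(lambda: stream.read(chunk_size), b""):
        digest.update(chunk)
    return digest.hexdigest()


def sha256_matches(path, expected_hex: str) -> bool:
    """True if the file at `path` has the expected SHA256 digest (compared case-insensitively)."""
    with Path(path).open("rb") as fh:
        return sha256_hex(fh) == expected_hex.strip().lower()


# Known-answer check with the standard SHA-256 test vector for b"abc".
abc_digest = sha256_hex(io.BytesIO(b"abc"))

# Hypothetical usage with the wheel digest listed above (download the file first):
WHEEL_SHA256 = "807352aa0d86bc0631397b75855c373bc6a0995f5a0806e84fa993f071db5eae"
# sha256_matches("blaxbird-0.0.1-py3-none-any.whl", WHEEL_SHA256)
```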
