Flexible Neural Networks in JAX
FlareJax
FlareJax is a Python library for building neural networks and optimizers in JAX. It is designed to minimize the time between a new research idea and its implementation.
Features
- Mutable modules for quick and dirty modifications via `Module`
- Serialization of modules via `@saveable`, `save`, and `load`
- Systematic modification of modules via `flatten` and `unflatten`
- Safe handling of shared/cyclical references and static arguments through `filter_jit`
- Commonly used NN layers and optimizers included
- A small codebase that is relatively easy to understand and extend
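To illustrate the idea behind `flatten`/`unflatten` and shared-reference handling, the following pure-Python sketch splits a nested, mutable module graph into a dictionary of leaf values keyed by attribute path, records repeated objects as links instead of copies, and rebuilds the graph afterwards. It is an illustration under stated assumptions, not FlareJax's implementation: `Node`, `flatten_module`, and `unflatten_module` are hypothetical names, and the real functions may use a different representation.

```python
from typing import Any

# Hypothetical helpers illustrating the idea; this is not the FlareJax API.
Path = tuple[str, ...]  # attribute names leading from the root to a value


class Node:
    """Hypothetical stand-in for a mutable module: its state lives in __dict__."""

    def __init__(self, **attrs: Any) -> None:
        self.__dict__.update(attrs)


def flatten_module(root: Node) -> tuple[dict[Path, type], dict[Path, Any], dict[Path, Path]]:
    """Split a module graph into node types, leaf values, and reference links."""
    nodes: dict[Path, type] = {}
    leaves: dict[Path, Any] = {}
    links: dict[Path, Path] = {}
    seen: dict[int, Path] = {}  # id(node) -> path where the node was first visited

    def visit(obj: Any, path: Path) -> None:
        if not isinstance(obj, Node):
            leaves[path] = obj  # plain value: store it under its path
            return
        if id(obj) in seen:
            links[path] = seen[id(obj)]  # shared or cyclical: store a pointer only
            return
        seen[id(obj)] = path
        nodes[path] = type(obj)
        for name, value in vars(obj).items():
            visit(value, path + (name,))

    visit(root, ())
    return nodes, leaves, links


def unflatten_module(nodes: dict[Path, type], leaves: dict[Path, Any], links: dict[Path, Path]) -> Node:
    """Rebuild the graph so that every shared reference points at a single object."""
    built: dict[Path, Node] = {}
    for path in sorted(nodes, key=len):  # parents are created before their children
        cls = nodes[path]
        built[path] = cls.__new__(cls)  # skip __init__; state is restored below
        if path:
            setattr(built[path[:-1]], path[-1], built[path])
    for path, value in leaves.items():
        setattr(built[path[:-1]], path[-1], value)
    for path, target in links.items():
        setattr(built[path[:-1]], path[-1], built[target])
    return built[()]


# Usage: one submodule shared by two attributes, plus a cycle back to the root.
shared = Node(weight=[1.0, 2.0])
model = Node(encoder=shared, decoder=shared, scale=0.5)
model.parent = model

nodes, leaves, links = flatten_module(model)
leaves[("scale",)] *= 2  # a systematic edit applied to the flat view
rebuilt = unflatten_module(nodes, leaves, links)
```

Because the flat view is an ordinary dictionary keyed by paths, edits such as replacing every weight or selecting a subset of attributes become plain dictionary operations, while the links keep shared submodules shared after reconstruction.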
Quick Example
Define new modules by subclassing `Module`. All methods are callable PyTrees.
```python
import flarejax as flr
import jax.numpy as jnp
import jax.random as jrn
from jaxtyping import PRNGKeyArray


@flr.saveable("Example:Linear")  # optional, make saveable
class Linear(flr.Module):
    def __init__(self, key: PRNGKeyArray, dim: int):
        self.key = key  # kept so the weights can be created on the first call
        self.dim = dim
        self.w = None

    def __call__(self, x):
        # lazy initialization dependent on the input shape
        if self.w is None:
            self.w = jrn.normal(self.key, (x.shape[-1], self.dim))
        return x @ self.w


layer = Linear(jrn.PRNGKey(0), 3)
x = jnp.zeros((1, 4))

# the model is initialized after the first call
y = layer(x)
assert layer.w.shape == (4, 3)
```
For optimization, define a loss function that takes the module as its first argument.
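```python
def loss_fn(module, x, y):
    # mean squared error between the module's prediction and the target
    return jnp.mean((module(x) - y) ** 2)


opt = flr.opt.Adam(3e-4)
target = jnp.ones((1, 3))  # illustrative regression target

# automatically just-in-time compiled; returns the updated optimizer, module, and loss
opt, layer, loss = flr.train(opt, layer, loss_fn, x, target)
```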
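`flr.opt.Adam(3e-4)` is configured only by its learning rate in the example above. The sketch below spells out the standard Adam update (bias-corrected first and second moment estimates) in plain Python, minimizing the same kind of mean-squared-error loss as `loss_fn` for a one-parameter model `y ≈ w * x` with a hand-derived gradient. It assumes the usual defaults `b1=0.9`, `b2=0.999`, `eps=1e-8`, which may differ from FlareJax's; `AdamState`, `adam_step`, and `mse_and_grad` are hypothetical names, and unlike `flr.train` nothing here is differentiated automatically or JIT-compiled.

```python
import math
from dataclasses import dataclass, field


@dataclass
class AdamState:
    """Hypothetical optimizer state: step count plus per-parameter moment estimates."""

    step: int = 0
    m: list[float] = field(default_factory=list)  # running mean of gradients
    v: list[float] = field(default_factory=list)  # running mean of squared gradients


def adam_step(params: list[float], grads: list[float], state: AdamState,
              lr: float = 3e-4, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8
              ) -> tuple[list[float], AdamState]:
    """Apply one standard Adam update; updates `state` in place and returns it."""
    if not state.m:
        state.m = [0.0] * len(params)
        state.v = [0.0] * len(params)
    state.step += 1
    new_params = []
    for i, (p, g) in enumerate(zip(params, grads)):
        state.m[i] = b1 * state.m[i] + (1 - b1) * g
        state.v[i] = b2 * state.v[i] + (1 - b2) * g * g
        m_hat = state.m[i] / (1 - b1 ** state.step)  # bias-corrected first moment
        v_hat = state.v[i] / (1 - b2 ** state.step)  # bias-corrected second moment
        new_params.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
    return new_params, state


def mse_and_grad(w: float, xs: list[float], ys: list[float]) -> tuple[float, float]:
    """Mean squared error of y ≈ w * x and its derivative with respect to w."""
    n = len(xs)
    residuals = [w * x - y for x, y in zip(xs, ys)]
    loss = sum(r * r for r in residuals) / n
    grad = sum(2 * r * x for r, x in zip(residuals, xs)) / n
    return loss, grad


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0 * x for x in xs]  # data generated with a true weight of 2.0
params, state = [0.0], AdamState()
initial_loss, _ = mse_and_grad(params[0], xs, ys)
for _ in range(2000):
    _, grad = mse_and_grad(params[0], xs, ys)
    params, state = adam_step(params, [grad], state, lr=1e-2)
```

A larger learning rate than `3e-4` is used here only so that the toy problem converges within a few thousand steps.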
Models can be saved and loaded.
```python
flr.save(layer, "model.npz")

# load the model
layer = flr.load("model.npz")
assert isinstance(layer, Linear)
```
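The string passed to `@flr.saveable` suggests that saved files identify classes by a registered name. The pure-Python sketch below shows that registry pattern under this assumption; `saveable`, `save_json`, `load_json`, and `REGISTRY` are hypothetical stand-ins for `flr.saveable`, `flr.save`, and `flr.load`, and a JSON string replaces the `.npz` file purely for illustration.

```python
import json
from typing import Any, Callable, TypeVar

T = TypeVar("T", bound=type)

REGISTRY: dict[str, type] = {}  # hypothetical registry: saved name -> class


def saveable(name: str) -> Callable[[T], T]:
    """Hypothetical class decorator that registers a class under a stable name."""
    def register(cls: T) -> T:
        if name in REGISTRY:
            raise ValueError(f"name {name!r} is already registered")
        REGISTRY[name] = cls
        cls._saveable_name = name  # remembered so save_json can write it out
        return cls
    return register


def save_json(obj: Any) -> str:
    """Serialize the registered class name together with the instance attributes."""
    return json.dumps({"name": type(obj)._saveable_name, "state": vars(obj)})


def load_json(text: str) -> Any:
    """Look the class up by name and restore its attributes without calling __init__."""
    data = json.loads(text)
    cls = REGISTRY[data["name"]]
    obj = cls.__new__(cls)
    obj.__dict__.update(data["state"])
    return obj


@saveable("Example:Scale")
class Scale:
    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, x: float) -> float:
        return self.factor * x


text = save_json(Scale(3.0))
restored = load_json(text)
```

One benefit of looking classes up by an explicit name, rather than by import path as `pickle` does, is that a class can be renamed or moved to another module without invalidating existing files, as long as its registered name stays the same.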
Installation
FlareJax can be installed via pip. It requires Python 3.10 or higher and JAX 0.4.33 or higher.

```bash
pip install flarejax
```
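To confirm these requirements before installing, a small standard-library script such as the following can check the running interpreter and any installed `jax` distribution. It is a convenience sketch, not part of FlareJax; `check_requirements` and `version_tuple` are hypothetical helpers, and the simple parser ignores pre-release ordering (for example, it treats `0.4.33rc1` as `0.4.33`).

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def version_tuple(text: str) -> tuple[int, ...]:
    """Parse the leading numeric release segment, e.g. '0.4.33' -> (0, 4, 33)."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # a segment such as 'dev0' ends the release number
        parts.append(int(digits))
        if len(digits) != len(piece):
            break  # stop at suffixes such as 'rc1'
    return tuple(parts)


def check_requirements() -> list[str]:
    """Return a list of problems with the current environment (empty if none)."""
    problems = []
    if sys.version_info < (3, 10):
        problems.append(f"Python {sys.version.split()[0]} is older than 3.10")
    try:
        jax_version = version("jax")
    except PackageNotFoundError:
        problems.append("jax is not installed")
    else:
        if version_tuple(jax_version) < (0, 4, 33):
            problems.append(f"jax {jax_version} is older than 0.4.33")
    return problems


if __name__ == "__main__":
    problems = check_requirements()
    print("\n".join(problems) if problems else "requirements satisfied")
```

For full PEP 440 comparison semantics, `packaging.version.Version` is the more robust choice when the `packaging` distribution is available.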
See Also
- JAX Docs: JAX is a library for numerical computing that is designed to be composable and fast.
- Equinox library: FlareJax is heavily inspired by this awesome library.
- `torch.nn.Module`: Many of the principles of mutability are inspired by PyTorch's `torch.nn.Module`.
- NNX Docs: NNX is a library for neural networks in Flax that also supports mutability.
- Always feel free to reach out to me via email.
Download files
- Source distribution: flarejax-0.4.10.tar.gz (20.1 kB)
- Built distribution: flarejax-0.4.10-py3-none-any.whl (24.5 kB)
File details
Details for the file flarejax-0.4.10.tar.gz.
File metadata
- Download URL: flarejax-0.4.10.tar.gz
- Size: 20.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 251e8a70a9460736045a1d67b9b6d43b21b8b3760d00660619232f0d99238200 |
| MD5 | 9b0106a1b8973ac5121b6be6cd360a3a |
| BLAKE2b-256 | bb6697f097cdf123c888679d4e7d12ba7840a5aed5cd09a2173826c8b844db0c |
File details
Details for the file flarejax-0.4.10-py3-none-any.whl.
File metadata
- Download URL: flarejax-0.4.10-py3-none-any.whl
- Size: 24.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 83fc293e2ea8c103894df4d26904b73583ba766ac7badcd1fcf9c411b3b8ff84 |
| MD5 | b30b366245e596af0955999f3c4a628f |
| BLAKE2b-256 | b1c8327bed59c8642f0bf2e0d25cd09a9e3b1ee804e0881c7b4fa18099d35527 |