Flexible Neural Networks in JAX
FlareJax
FlareJax is a Python library for building neural networks and optimizers in Jax. It is designed to minimize the time between a new research idea and its implementation.
Features
- Mutable modules for quick and dirty modifications via `Module`
- Serialization of modules via `@saveable`, `save`, and `load`
- Systematically modifying modules by using `flatten` and `unflatten`
- Safely handling shared/cyclical references and static arguments through `filter_jit`
- Commonly used NN layers and optimizers are included
- As a small codebase, it is relatively easy to understand and extend
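The exact signatures of FlareJax's `flatten` and `unflatten` are not documented here, so the sketch below uses JAX's own `jax.tree_util` to illustrate the general idea of systematically modifying a model: flatten the parameter tree into (path, leaf) pairs, rewrite the leaves you care about, and rebuild the tree. The nested dict of parameters and the `zero_biases` helper are hypothetical stand-ins, not FlareJax APIs.

```python
import jax
import jax.numpy as jnp


def zero_biases(tree):
    """Hypothetical helper: return a copy of `tree` with every leaf whose path mentions "bias" set to zero."""
    # flatten into (key path, leaf) pairs plus a treedef describing the structure
    leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(tree)
    new_leaves = [
        # keystr renders a path such as "['layer1']['bias']"
        jnp.zeros_like(leaf) if "bias" in jax.tree_util.keystr(path) else leaf
        for path, leaf in leaves_with_paths
    ]
    # rebuild a tree with the original structure from the modified leaves
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


params = {
    "layer1": {"w": jnp.ones((2, 2)), "bias": jnp.ones(2)},
    "layer2": {"w": jnp.ones((2, 1)), "bias": jnp.ones(1)},
}
params = zero_biases(params)
```

Working on flat (path, leaf) pairs keeps the rewrite rule independent of how deeply the layers are nested.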
Quick Example
Define new Modules by subclassing `Module`. All methods are callable PyTrees.
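```python
import jax.numpy as jnp
import jax.random as jrn
import flarejax as flr
from jaxtyping import PRNGKeyArray  # type alias for a JAX PRNG key


@flr.saveable("Example:Linear")  # optional, makes the class saveable
class Linear(flr.Module):
    def __init__(self, key: PRNGKeyArray, dim: int):
        self.key = key  # stored so that __call__ can use it for lazy initialization
        self.dim = dim
        self.w = None

    def __call__(self, x):
        # lazy initialization dependent on the input shape
        if self.w is None:
            self.w = jrn.normal(self.key, (x.shape[-1], self.dim))
        return x @ self.w


layer = Linear(jrn.PRNGKey(0), 3)
x = jnp.zeros((1, 4))

# the weights are initialized on the first call
y = layer(x)
assert layer.w.shape == (4, 3)
```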
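FlareJax's `Module` takes care of the PyTree bookkeeping. As a rough plain-JAX sketch of what "callable PyTree" means (this is not FlareJax's implementation), the hypothetical `TinyLinear` class below registers itself with `jax.tree_util`, so `jax.grad` can differentiate with respect to the module, and `jax.tree_util.Partial` turns its `__call__` method into a callable whose leaves are the module's arrays.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class TinyLinear:
    """Hypothetical mutable layer; not FlareJax's Module."""

    def __init__(self, w):
        self.w = w  # the weight array is the only PyTree leaf

    def __call__(self, x):
        return x @ self.w

    def tree_flatten(self):
        # (children, auxiliary static data)
        return (self.w,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


layer = TinyLinear(jnp.ones((4, 3)))
x = jnp.ones((1, 4))

# the module is a PyTree, so the gradient comes back as a TinyLinear as well
grads = jax.grad(lambda m, x: jnp.sum(m(x)))(layer, x)

# attributes remain mutable on the Python side
layer.w = layer.w - 0.1 * grads.w

# Partial wraps the method into a callable PyTree that jax.jit accepts as an argument
call = jax.tree_util.Partial(TinyLinear.__call__, layer)
y = jax.jit(lambda f, x: f(x))(call, x)
```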
For optimization, define a loss function that takes the module as its first argument.

```python
def loss_fn(module, x, y):
    return jnp.mean((module(x) - y) ** 2)


opt = flr.opt.Adam(3e-4)

# automatically just-in-time compiled
opt, layer, loss = flr.train(opt, layer, loss_fn, x, y)
```
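`flr.train` hides the details of a training step. As a rough illustration of what one jitted step involves, the plain-JAX sketch below computes the loss and its gradients with `jax.value_and_grad` and updates every parameter with `jax.tree_util.tree_map`. It assumes a dict of parameters instead of a FlareJax module and uses plain gradient descent instead of Adam; `sgd_step` and `LEARNING_RATE` are hypothetical names, not part of FlareJax.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical step size for plain gradient descent


def loss_fn(params, x, y):
    # as with the FlareJax loss, the model (here a dict of arrays) comes first
    return jnp.mean((x @ params["w"] - y) ** 2)


@jax.jit
def sgd_step(params, x, y):
    # loss value and gradients w.r.t. the first argument in one pass
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # update every leaf of the parameter tree: p <- p - lr * g
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    # note: the returned loss is the one computed before this update
    return params, loss


params = {"w": jnp.ones((4, 3))}
x = jnp.ones((8, 4))
y = jnp.zeros((8, 3))

for _ in range(20):
    params, loss = sgd_step(params, x, y)
```

Unlike this sketch, `flr.train` also hands back the optimizer, which is where a stateful optimizer such as Adam keeps its running statistics.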
Models can be saved and loaded.

```python
flr.save(layer, "model.npz")

# load the model
layer = flr.load("model.npz")
assert isinstance(layer, Linear)
```
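The on-disk layout that FlareJax uses is not described here. As a sketch of the general idea behind storing parameters in an `.npz` archive, the hypothetical helpers below flatten a parameter PyTree, write the leaves with `numpy.savez`, and rebuild the tree on load. Unlike `flr.load`, this sketch needs the `treedef` kept in memory to restore the structure, and it does not reproduce the class-restoring behavior that `@saveable` enables in FlareJax.

```python
import jax
import jax.numpy as jnp
import numpy as np


def save_params(params, path):
    """Hypothetical helper: store the leaves of `params` in an .npz file and return the treedef."""
    leaves, treedef = jax.tree_util.tree_flatten(params)
    # positional arrays are stored under the names arr_0, arr_1, ...
    np.savez(path, *[np.asarray(leaf) for leaf in leaves])
    return treedef


def load_params(path, treedef):
    """Hypothetical helper: read the leaves back in order and rebuild the tree."""
    with np.load(path) as data:
        leaves = [jnp.asarray(data[f"arr_{i}"]) for i in range(len(data.files))]
    return jax.tree_util.tree_unflatten(treedef, leaves)
```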
Installation
FlareJax can be installed via pip. It requires Python 3.10 or higher and Jax 0.4.33 or higher.
```shell
pip install flarejax
```
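To check the stated requirements in an existing environment, a small script like the one below can help. It is only a sketch: `version_tuple` and `meets_requirements` are hypothetical helpers, and the regular expression looks only at the leading `major.minor.patch` numbers of `jax.__version__`, ignoring suffixes such as `.dev` tags.

```python
import re
import sys

import jax


def version_tuple(version):
    """Parse the leading major.minor.patch numbers, e.g. "0.4.33" -> (0, 4, 33)."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def meets_requirements():
    """Return (python_ok, jax_ok) for Python >= 3.10 and JAX >= 0.4.33."""
    python_ok = sys.version_info >= (3, 10)
    jax_ok = version_tuple(jax.__version__) >= (0, 4, 33)
    return python_ok, jax_ok


print(meets_requirements())
```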
See Also
- Jax Docs: Jax is a library for numerical computing that is designed to be composable and fast.
- Equinox library: FlareJax is heavily inspired by this awesome library.
- torch.nn.Module: Many of the principles of mutability are inspired by PyTorch's `torch.nn.Module`.
- NNX Docs: NNX is a library for neural networks in Flax that also supports mutability.
- Always feel free to reach out to me via email.
Download files
Source Distribution
flarejax-0.4.9.tar.gz (20.1 kB)
Built Distribution
flarejax-0.4.9-py3-none-any.whl (24.5 kB)
File details
Details for the file flarejax-0.4.9.tar.gz.
File metadata
- Download URL: flarejax-0.4.9.tar.gz
- Upload date:
- Size: 20.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7c4d8a91677d97c9796f8d9518773ad0824e72c3dd24ade411cf8f4c78938516 |
| MD5 | 634fbb0321ab80aab27c0acda70ecc42 |
| BLAKE2b-256 | 4133e173d21f8f2bfcca6b0afd5091770698f92b68069dcfe65ecff5cec81e35 |
File details
Details for the file flarejax-0.4.9-py3-none-any.whl.
File metadata
- Download URL: flarejax-0.4.9-py3-none-any.whl
- Upload date:
- Size: 24.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f4dfe078c1a7b82beb8a97154eb85f9906e493c3dc9bacc0df5ba1195479fd49 |
| MD5 | ac09435ece8031bac8a14d188f981b1d |
| BLAKE2b-256 | 310c88db23a63385a37429c24a6a6a462c744eeef2b127136e9e100670ce32fd |
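To check a downloaded file against the SHA256 digests listed for flarejax 0.4.9, Python's standard `hashlib` module is sufficient; `sha256_of` and `EXPECTED_SHA256` are hypothetical names used only for illustration. pip can also enforce digests at install time through its `--require-hashes` mode.

```python
import hashlib

# SHA256 digests published for flarejax 0.4.9
EXPECTED_SHA256 = {
    "flarejax-0.4.9.tar.gz": "7c4d8a91677d97c9796f8d9518773ad0824e72c3dd24ade411cf8f4c78938516",
    "flarejax-0.4.9-py3-none-any.whl": "f4dfe078c1a7b82beb8a97154eb85f9906e493c3dc9bacc0df5ba1195479fd49",
}


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

For example, after downloading the wheel, `sha256_of("flarejax-0.4.9-py3-none-any.whl") == EXPECTED_SHA256["flarejax-0.4.9-py3-none-any.whl"]` should hold.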