Flexible Neural Networks in JAX
FlareJax
FlareJax is a Python library for building neural networks and optimizers in Jax. It is designed to minimize the time between a new research idea and its implementation.
Features
- Mutable modules for quick and dirty modifications via `Module`
- Serialization of modules via `@saveable`, `save`, and `load`
- Systematically modifying modules by using `flatten` and `unflatten`
- Safely handling shared/cyclical references and static arguments through `filter_jit`
- Commonly used NN layers and optimizers are included
- The codebase is small, so it is relatively easy to understand and extend
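The exact signatures of FlareJax's `flatten` and `unflatten` are not documented here. The sketch below therefore shows the same flatten, modify, and unflatten idea using plain `jax.tree_util` instead. It works on a hypothetical nested dictionary of parameters that stands in for a module. Each leaf is flattened together with its path, only the leaves named `w` are rescaled, and the original nesting is then rebuilt.

```python
import jax
import jax.numpy as jnp

# Hypothetical nested parameters standing in for a module (not FlareJax's API).
params = {
    "linear1": {"w": jnp.ones((4, 3)), "b": jnp.zeros(3)},
    "linear2": {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)},
}

# Flatten into (path, leaf) pairs plus a treedef describing the nesting.
leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(params)

# Systematically modify leaves: halve every weight named "w", keep biases as-is.
new_leaves = [
    leaf * 0.5 if path[-1].key == "w" else leaf
    for path, leaf in leaves_with_paths
]

# Rebuild the original nested structure from the modified leaves.
scaled = jax.tree_util.tree_unflatten(treedef, new_leaves)
```

Selecting leaves by path instead of by position keeps the transformation robust when new submodules are added.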
Quick Example
Define new Modules by subclassing `Module`. All methods are callable PyTrees.
```python
import jax.numpy as jnp
import jax.random as jrn
from jaxtyping import PRNGKeyArray

import flarejax as flr


@flr.saveable("Example:Linear")  # optional, makes the module saveable
class Linear(flr.Module):
    def __init__(self, key: PRNGKeyArray, dim: int):
        self.key = key  # stored so the weights can be created lazily
        self.dim = dim
        self.w = None

    def __call__(self, x):
        # lazy initialization dependent on the input shape
        if self.w is None:
            self.w = jrn.normal(self.key, (x.shape[-1], self.dim))
        return x @ self.w


layer = Linear(jrn.PRNGKey(0), 3)
x = jnp.zeros((1, 4))

# the model is initialized after the first call
y = layer(x)
assert layer.w.shape == (4, 3)
```
For optimization, define a loss function that takes the module as its first argument.
```python
def loss_fn(module, x, y):
    # mean squared error between the module's prediction and the target
    return jnp.mean((module(x) - y) ** 2)


opt = flr.opt.Adam(3e-4)

# automatically just-in-time compiled
opt, layer, loss = flr.train(opt, layer, loss_fn, x, y)
```
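`flr.train` computes gradients and compiles the step internally, and its internals are not described here. As a rough mental model only, the sketch below runs a jit-compiled gradient step with Adam in plain JAX. It operates on a hypothetical parameter dictionary. The functional `loss_fn(params, x, y)`, the hand-written `adam_update`, and `train_step` are illustrative assumptions, not FlareJax's implementation.

```python
import jax
import jax.numpy as jnp

tmap = jax.tree_util.tree_map


def loss_fn(params, x, y):
    # Mean squared error of a linear model, mirroring the loss above.
    return jnp.mean((x @ params["w"] - y) ** 2)


def adam_update(grads, state, lr=3e-4, b1=0.9, b2=0.999, eps=1e-8):
    # Standard Adam: exponential moving averages of g and g^2, bias-corrected.
    step = state["step"] + 1
    m = tmap(lambda m, g: b1 * m + (1 - b1) * g, state["m"], grads)
    v = tmap(lambda v, g: b2 * v + (1 - b2) * g**2, state["v"], grads)
    m_hat = tmap(lambda m: m / (1 - b1**step), m)
    v_hat = tmap(lambda v: v / (1 - b2**step), v)
    updates = tmap(lambda m, v: -lr * m / (jnp.sqrt(v) + eps), m_hat, v_hat)
    return updates, {"step": step, "m": m, "v": v}


@jax.jit
def train_step(params, state, x, y):
    # Differentiate the loss with respect to the first argument (the params).
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, state = adam_update(grads, state)
    params = tmap(lambda p, u: p + u, params, updates)
    return params, state, loss


params = {"w": jax.random.normal(jax.random.PRNGKey(0), (4, 3))}
state = {
    "step": jnp.array(0),
    "m": tmap(jnp.zeros_like, params),
    "v": tmap(jnp.zeros_like, params),
}
x, y = jnp.ones((8, 4)), jnp.zeros((8, 3))
initial_loss = loss_fn(params, x, y)
for _ in range(100):
    params, state, loss = train_step(params, state, x, y)
```

Passing the optimizer state in and out explicitly mirrors how `flr.train` returns an updated `opt` alongside the updated module.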
Models can be saved and loaded.
```python
# save the model (its class is registered via @flr.saveable)
flr.save(layer, "model.npz")

# load the model
layer = flr.load("model.npz")
assert isinstance(layer, Linear)
```
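The on-disk layout that FlareJax writes to `model.npz` is not described here. The sketch below is an assumption, not the library's format. It shows one common way to round-trip a nested dictionary of arrays through NumPy's `.npz` format. Nested keys are joined into slash-separated names, saved with `np.savez`, and rebuilt on load. The helpers `flatten_dict`, `unflatten_dict`, `save_tree`, and `load_tree` are hypothetical.

```python
import os
import tempfile

import numpy as np


def flatten_dict(tree, prefix=""):
    # Turn {"a": {"b": arr}} into {"a/b": arr}.
    flat = {}
    for key, value in tree.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten_dict(value, prefix=f"{name}/"))
        else:
            flat[name] = np.asarray(value)
    return flat


def unflatten_dict(flat):
    # Rebuild nested dictionaries from slash-separated names.
    tree = {}
    for name, value in flat.items():
        *parents, leaf = name.split("/")
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


def save_tree(tree, path):
    np.savez(path, **flatten_dict(tree))


def load_tree(path):
    with np.load(path) as data:
        return unflatten_dict({name: data[name] for name in data.files})


params = {"linear": {"w": np.ones((4, 3)), "b": np.zeros(3)}}
path = os.path.join(tempfile.mkdtemp(), "model.npz")
save_tree(params, path)
restored = load_tree(path)
```

A plain array file like this stores only the weights. Restoring the Python class itself, as `flr.load` does when it returns a `Linear`, needs extra metadata such as the name registered with `@flr.saveable("Example:Linear")`.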
Installation
FlareJax can be installed via pip. It requires Python 3.10 or higher and Jax 0.4.33 or higher.
```bash
pip install flarejax
```
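You may want to check that an environment meets the stated minimums (Python 3.10, Jax 0.4.33) before installing. A small pure-Python check like the following can do that. The helpers `parse_version` and `meets_minimum` are hypothetical. They only handle plain dotted release numbers and stop at the first non-numeric component, so a tag such as `0.4.33rc1` is not parsed fully.

```python
import sys


def parse_version(version: str) -> tuple[int, ...]:
    # "0.4.33" -> (0, 4, 33); stops at the first non-numeric component.
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def meets_minimum(version: str, minimum: tuple[int, ...]) -> bool:
    # Tuples compare element-wise, so (0, 4, 35) >= (0, 4, 33).
    return parse_version(version) >= minimum


python_ok = sys.version_info >= (3, 10)
```

With Jax installed, `meets_minimum(jax.__version__, (0, 4, 33))` reports whether the installed version is recent enough.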
See Also
- Jax Docs: Jax is a library for numerical computing that is designed to be composable and fast.
- Equinox library: FlareJax is heavily inspired by this awesome library.
- torch.nn.Module: Many of the principles of mutability are inspired by PyTorch's `torch.nn.Module`.
- NNX Docs: NNX is a library for neural networks in Flax that also supports mutability.
- Always feel free to reach out to me via email.
Download files
Source Distribution
flarejax-0.4.8.tar.gz (19.9 kB)
Built Distribution
flarejax-0.4.8-py3-none-any.whl (23.9 kB)
File details
Details for the file flarejax-0.4.8.tar.gz.
File metadata
- Download URL: flarejax-0.4.8.tar.gz
- Upload date:
- Size: 19.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 49208a24904b508cbcd0dddcedd1057f06c89d8b8e375ba445e1abe2752e7fb5
MD5 | 0a4badb3d9f6c7e82b83c822bc7ae83b
BLAKE2b-256 | 5aaabac31c6a5ff7aeff17ef964607cead2324ff4c1e615861e3753e9ce8e89d
File details
Details for the file flarejax-0.4.8-py3-none-any.whl.
File metadata
- Download URL: flarejax-0.4.8-py3-none-any.whl
- Upload date:
- Size: 23.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 44460f76cf9a26d737fc7dd6b70ad98ab2f96e0d0baed9f2d862115263154272
MD5 | af067abfadb2e7140fb71c014b771e6a
BLAKE2b-256 | bbd4d1d49545d12a78ed22d6bd7739145c25b34fcf695e47cd69033af761f6fe