Flexible Neural Networks in JAX
FlareJax
FlareJax is a Python library for building neural networks and optimizers in Jax. It is designed to minimize the time between a new research idea and its implementation.
Features
- Mutable modules for quick and dirty modifications via `Module`
- Serialization of modules via `@saveable`, `save`, and `load`
- Systematically modifying modules by using `flatten` and `unflatten`
- Safely handling shared/cyclical references and static arguments through `filter_jit`
- Commonly used NN layers and optimizers are included
- As a small codebase, it is relatively easy to understand and extend
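The general technique behind a "filtered" jit, as popularized by Equinox's `filter_jit`, can be sketched in plain JAX: array leaves of the arguments are traced, while every other leaf (Python numbers, strings, functions) is treated as static and becomes part of the compilation cache key. This is an illustration of the idea, not FlareJax's code; `filter_jit_sketch` and `_ARRAY` are hypothetical names, and the sketch does not handle shared or cyclical references.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

_ARRAY = object()  # hypothetical placeholder marking where an array leaf was


def filter_jit_sketch(fn):
    """Jit `fn`, tracing array leaves and treating all other leaves as static."""

    @functools.partial(jax.jit, static_argnums=(1, 2))
    def run(dynamic, static, treedef):
        # Put the traced arrays back into the slots marked with `_ARRAY`.
        arrays = iter(dynamic)
        leaves = [next(arrays) if s is _ARRAY else s for s in static]
        args = jax.tree_util.tree_unflatten(treedef, leaves)
        return fn(*args)

    @functools.wraps(fn)
    def wrapper(*args):
        leaves, treedef = jax.tree_util.tree_flatten(args)
        is_array = [isinstance(x, (jax.Array, np.ndarray)) for x in leaves]
        dynamic = [x for x, a in zip(leaves, is_array) if a]
        # Static leaves must be hashable: they are part of the cache key.
        static = tuple(_ARRAY if a else x for x, a in zip(leaves, is_array))
        return run(dynamic, static, treedef)

    return wrapper


@filter_jit_sketch
def scaled_activation(x, factor, activation):
    y = x * factor  # `factor` is a Python float, compiled in as a constant
    # Python control flow on `activation` works because it is static, not traced.
    if activation == "relu":
        return jax.nn.relu(y)
    return y


out = scaled_activation(jnp.array([-1.0, 2.0]), 3.0, "relu")
```

Calling the function again with a different `factor` or `activation` triggers a recompilation, which is the usual trade-off of treating non-array arguments as static.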
Quick Example
Define new Modules by subclassing Module. All methods are callable PyTrees.
```python
import flarejax as flr
import jax.numpy as jnp
import jax.random as jrn
from jaxtyping import PRNGKeyArray


@flr.saveable("Example:Linear")  # optional, make saveable
class Linear(flr.Module):
    def __init__(self, key: PRNGKeyArray, dim: int):
        self.key = key  # keep the key for the deferred initialization
        self.dim = dim
        self.w = None

    def __call__(self, x):
        # lazy initialization dependent on the input shape
        if self.w is None:
            self.w = jrn.normal(self.key, (x.shape[-1], self.dim))
        return x @ self.w


layer = Linear(jrn.PRNGKey(0), 3)
x = jnp.zeros((1, 4))

# the model is initialized after the first call
y = layer(x)
assert layer.w.shape == (4, 3)
```
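The statement that methods are callable PyTrees can be made concrete with plain JAX's built-in callable PyTree, `jax.tree_util.Partial`: a partially applied function whose bound arguments are PyTree leaves. It is used here only as an analogy for the concept, not as a description of FlareJax internals; `linear`, `apply`, and `run` are illustrative names.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def linear(w, x):
    # A plain function; the weights are passed explicitly.
    return x @ w


# Bind the weights: `apply` is callable, and the bound array is a PyTree leaf.
apply = Partial(linear, jnp.ones((4, 3)))

# Because `apply` is a PyTree, tree utilities can reach its parameters.
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, apply)


@jax.jit
def run(f, x):
    # `f` is an ordinary jit argument: its leaves are traced, while the
    # wrapped function `linear` is kept as static metadata.
    return f(x)


y = run(doubled, jnp.ones((1, 4)))
```

Since the callable carries its parameters as leaves, it can be transformed and passed through `jax.jit` without any special handling.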
For optimization, define a loss function, which takes the module as the first argument.
```python
def loss_fn(module, x, y):
    return jnp.mean((module(x) - y) ** 2)


opt = flr.opt.Adam(3e-4)

# automatically just-in-time compiled
opt, layer, loss = flr.train(opt, layer, loss_fn, x, y)
```
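Roughly what a jitted training step like `flr.train` has to do can be sketched in plain JAX: compute the loss and its gradient with `jax.value_and_grad`, then apply an Adam update. This is a minimal illustration under stated assumptions (parameters held in a dict, textbook Adam with bias correction, the usual default β and ε values), not FlareJax's implementation; `adam_init` and `train_step` are hypothetical names.

```python
import jax
import jax.numpy as jnp

LR, B1, B2, EPS = 3e-4, 0.9, 0.999, 1e-8  # assumed Adam hyperparameters


def loss_fn(params, x, y):
    # Mean squared error of a linear model, mirroring `loss_fn` above.
    return jnp.mean((x @ params["w"] - y) ** 2)


def adam_init(params):
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    return {"m": zeros, "v": zeros, "t": jnp.zeros((), jnp.int32)}


@jax.jit
def train_step(state, params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    t = state["t"] + 1
    # Exponential moving averages of the gradient and of its square.
    m = jax.tree_util.tree_map(lambda m, g: B1 * m + (1 - B1) * g, state["m"], grads)
    v = jax.tree_util.tree_map(lambda v, g: B2 * v + (1 - B2) * g**2, state["v"], grads)

    def update(p, m, v):
        # Bias-corrected Adam step.
        m_hat = m / (1 - B1**t)
        v_hat = v / (1 - B2**t)
        return p - LR * m_hat / (jnp.sqrt(v_hat) + EPS)

    params = jax.tree_util.tree_map(update, params, m, v)
    return {"m": m, "v": v, "t": t}, params, loss


x = jax.random.normal(jax.random.PRNGKey(0), (32, 4))
y = x @ jnp.ones((4, 3))  # targets from a known linear map
params = {"w": jnp.zeros((4, 3))}
state = adam_init(params)

losses = []
for _ in range(100):
    state, params, loss = train_step(state, params, x, y)
    losses.append(float(loss))
```

Like the `flr.train` call above, each step returns the updated optimizer state, the updated parameters, and the loss.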
Models can be saved and loaded.
```python
flr.save(layer, "model.npz")

# load the model
layer = flr.load("model.npz")
assert isinstance(layer, Linear)
```
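The `.npz` file format is NumPy's zipped archive of named arrays. A minimal sketch of storing a PyTree of arrays in that format is shown below. FlareJax's `load` reconstructs the `Linear` instance on its own, which this sketch does not attempt: here the tree structure comes from a template, and only the array values are read from disk. `save_tree` and `load_tree` are hypothetical helpers.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def save_tree(tree, path):
    # Store each leaf under its position in the flattened tree: "leaf_0", "leaf_1", ...
    leaves = jax.tree_util.tree_leaves(tree)
    np.savez(path, **{f"leaf_{i}": np.asarray(leaf) for i, leaf in enumerate(leaves)})


def load_tree(template, path):
    # The structure comes from `template`; only the array values come from disk.
    treedef = jax.tree_util.tree_structure(template)
    with np.load(path) as data:
        leaves = [jnp.asarray(data[f"leaf_{i}"]) for i in range(treedef.num_leaves)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


params = {"w": jnp.arange(12.0).reshape(4, 3), "b": jnp.zeros(3)}
path = os.path.join(tempfile.mkdtemp(), "params.npz")
save_tree(params, path)
restored = load_tree(params, path)
```

Keying leaves by their flattened position relies on the template having exactly the same structure as the saved tree; recording a class name, as `@saveable` does, is one way to remove that requirement.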
Installation
FlareJax can be installed via pip. It requires Python 3.10 or higher and Jax 0.4.33 or higher.
```shell
pip install flarejax
```
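The stated minimum versions can be checked with a short script before installing. The helpers `version_tuple` and `meets_minimum` are hypothetical and only compare the leading numeric components of a version string, which suffices for releases like `0.4.33` but ignores pre-release and dev tags.

```python
import sys


def version_tuple(version):
    """Hypothetical helper: "0.4.33" -> (0, 4, 33); stops at the first non-numeric part."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) != len(piece):
            break  # e.g. "33rc1": keep 33, ignore the suffix and anything after
    return tuple(parts)


def meets_minimum(version, minimum):
    return version_tuple(version) >= version_tuple(minimum)


python_ok = sys.version_info >= (3, 10)
try:
    import jax

    jax_ok = meets_minimum(jax.__version__, "0.4.33")
except ImportError:
    jax_ok = False

print(f"Python >= 3.10: {python_ok}; jax >= 0.4.33: {jax_ok}")
```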
See Also
- Jax Docs: Jax is a library for numerical computing that is designed to be composable and fast.
- Equinox library: FlareJax is heavily inspired by this awesome library.
- torch.nn.Module: Many of the principles of mutability are inspired by PyTorch's `torch.nn.Module`.
- NNX Docs: NNX is a library for neural networks in Flax that also supports mutability.
- Always feel free to reach out to me via email.
File details
Details for the file flarejax-0.4.11.tar.gz.
File metadata
- Download URL: flarejax-0.4.11.tar.gz
- Upload date:
- Size: 20.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b550a793b991f6442564f8c697d4b0507a319e371abdce93f36eaf711de2fe84 |
| MD5 | ad64843b84be36980accc1ead9606a5d |
| BLAKE2b-256 | a4ec18244d8193cb758a047b8173cd9bdf8157a3de1cd7d0d8a654da0ae20271 |
File details
Details for the file flarejax-0.4.11-py3-none-any.whl.
File metadata
- Download URL: flarejax-0.4.11-py3-none-any.whl
- Upload date:
- Size: 24.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 31b32afa1f5793325c0ec2e1a3f01278125690b52c625aeb29355b1731f113e5 |
| MD5 | a59ac8b06b4921ee847341d413c93f43 |
| BLAKE2b-256 | 9eaa0f92d59d380a1a78127aa74ed7027084ddaccbf0619c0d371824fd8827ab |