Pytree module classes with easy manipulation and serialization
FlareJax
Simple pytree module classes for Jax, strongly inspired by Equinox
- Referential transparency via strict immutability
- Easy manipulation using `.at` & `.set`
- Safe serialization including hyperparameters
- Bound methods and function transformations are also modules
- Auxiliary information in key paths for filtered transformations
Installation
FlareJax can be installed directly from PyPI using `pip`. It requires Python 3.10+ and Jax 0.4.26+.
```shell
pip install flarejax
```
Quick Examples
Modules work similarly to dataclasses, with the added benefit of being pytrees, which makes them compatible with all Jax function transformations.
```python
import jax
import flarejax as fj


class Linear(fj.Module):
    # The __init__ method is automatically generated
    w: jax.Array
    b: jax.Array

    # Non-pytree fields are marked with leaf=False
    aux: None = fj.field(leaf=False, default=None)

    # Additional initialization methods via classmethods
    @classmethod
    def init(cls, key, dim_in, dim):
        w = jax.random.normal(key, (dim, dim_in))
        b = jax.numpy.zeros((dim,))
        return cls(w=w, b=b)

    def __call__(self, x):
        return self.w @ x + self.b


key = jax.random.PRNGKey(42)
key1, key2 = jax.random.split(key)

model = fj.Sequential(
    (
        Linear.init(key1, 3, 2),
        Linear.init(key2, 2, 5),
    )
)
```
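Being a pytree means a module can be split into children, which Jax transformations see, and static auxiliary data, which they leave alone. The plain-Python sketch below shows how a `leaf=False` flag could decide that split; it mirrors the pair of flatten/unflatten functions that `jax.tree_util.register_pytree_node` takes, but `module_field`, `flatten` and `unflatten` are hypothetical names, not FlareJax's implementation, and lists stand in for Jax arrays so it runs without Jax.

```python
from dataclasses import dataclass, field, fields
from typing import Any


def module_field(*, leaf: bool = True, **kwargs: Any) -> Any:
    """Hypothetical stand-in for fj.field: records the `leaf` flag in field metadata."""
    return field(metadata={"leaf": leaf}, **kwargs)


@dataclass(frozen=True)
class Linear:
    w: list
    b: list
    aux: Any = module_field(leaf=False, default=None)


def flatten(module: Any) -> tuple:
    """Split a module into (children, aux_data), like a custom pytree node's flatten."""
    children, static = [], []
    for f in fields(module):
        value = getattr(module, f.name)
        if f.metadata.get("leaf", True):
            children.append(value)  # would be visible to Jax transformations
        else:
            static.append((f.name, value))  # kept as static auxiliary data
    return children, (type(module), tuple(static))


def unflatten(aux_data: tuple, children: list) -> Any:
    """Rebuild a module of the same class from aux_data and new children."""
    cls, static = aux_data
    leaf_names = [f.name for f in fields(cls) if f.metadata.get("leaf", True)]
    return cls(**dict(zip(leaf_names, children)), **dict(static))


layer = Linear(w=[1.0, 2.0], b=[0.5], aux="meta")
children, aux_data = flatten(layer)
# A tree_map-style update: transform every child, keep the static data unchanged.
scaled = unflatten(aux_data, [[2 * v for v in child] for child in children])
print(scaled)  # Linear(w=[2.0, 4.0], b=[1.0], aux='meta')
```

Keeping `aux` out of the children is what lets a transformation such as `jax.jit` treat it as a compile-time constant instead of trying to trace it.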
Although modules are immutable, modified copies can be created using the `at` property.
```python
w_new = jax.numpy.ones((2, 3))
model = model.at[0].w.set(w_new)
```
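One way to picture the `at` property is as a recorder of attribute and index accesses that, on `.set`, rebuilds only the objects along that path. The sketch below implements that idea with frozen dataclasses in plain Python; `_At`, `_set_path` and the toy `Module`, `Linear` and `Sequential` classes are hypothetical stand-ins, not FlareJax's implementation, and tuples of floats replace Jax arrays so the example runs without Jax.

```python
from dataclasses import dataclass, field, replace
from typing import Any


class _At:
    """Records a chain of attribute/item accesses; .set() rebuilds copies along it."""

    def __init__(self, root: Any, path: tuple = ()):
        self._root = root
        self._path = path

    def __getattr__(self, name: str) -> "_At":
        if name.startswith("__"):  # don't hijack dunder lookups (copy, pickle, ...)
            raise AttributeError(name)
        return _At(self._root, self._path + (("attr", name),))

    def __getitem__(self, key: Any) -> "_At":
        return _At(self._root, self._path + (("item", key),))

    def set(self, value: Any) -> Any:
        return _set_path(self._root, self._path, value)


def _set_path(node: Any, path: tuple, value: Any) -> Any:
    """Return a copy of `node` with the leaf at `path` replaced; nothing is mutated."""
    if not path:
        return value
    (kind, key), rest = path[0], path[1:]
    if kind == "attr":
        return replace(node, **{key: _set_path(getattr(node, key), rest, value)})
    if isinstance(node, Sequential):
        # Indexing a Sequential addresses its `layers` tuple.
        return replace(node, layers=_set_path(node.layers, path, value))
    if isinstance(node, tuple):
        items = list(node)
        items[key] = _set_path(node[key], rest, value)
        return tuple(items)
    if isinstance(node, dict):
        return {**node, key: _set_path(node[key], rest, value)}
    raise TypeError(f"cannot index into {type(node).__name__}")


@dataclass(frozen=True)
class Module:
    @property
    def at(self) -> _At:
        return _At(self)


@dataclass(frozen=True)
class Linear(Module):
    w: tuple
    b: tuple
    config: dict = field(default_factory=lambda: {"train": True})


@dataclass(frozen=True)
class Sequential(Module):
    layers: tuple

    def __getitem__(self, i: int) -> Module:
        return self.layers[i]


model = Sequential((Linear(w=(0.0, 0.0), b=(0.0,)), Linear(w=(1.0,), b=(1.0,))))
updated = model.at[0].w.set((1.0, 1.0))
updated = updated.at[0].config["train"].set(False)
print(updated[0])  # Linear(w=(1.0, 1.0), b=(0.0,), config={'train': False})
print(model[0])    # Linear(w=(0.0, 0.0), b=(0.0,), config={'train': True})
```

Only the objects on the updated path are rebuilt; in this sketch the untouched second layer is shared between `model` and `updated`, which keeps copy-on-write updates cheap.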
Turning off train mode for the first layer takes only a single call to `set`.
```python
model = model.at[0].config["train"].set(False)
```
The model can be serialized and deserialized using `fj.save` and `fj.load`.
```python
fj.save("model.npz", model)
model = fj.load("model.npz")
```
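A common way to make serialization safe is to avoid pickle entirely: store parameters and hyperparameters as plain data and rebuild objects only from an explicit allow-list of classes. The sketch below illustrates that general idea with JSON and nested lists in place of arrays; it is an assumption about the approach, not a description of the `.npz` files that `fj.save` actually writes, and `REGISTRY`, `register`, `to_json` and `from_json` are hypothetical names.

```python
import json
from dataclasses import dataclass, fields, is_dataclass
from typing import Any

# Hypothetical allow-list: only classes registered here can be rebuilt on load,
# so a loaded file cannot make the loader construct arbitrary objects.
REGISTRY: dict = {}


def register(cls: type) -> type:
    REGISTRY[cls.__name__] = cls
    return cls


@register
@dataclass(frozen=True)
class Linear:
    w: list
    b: list
    activation: str = "relu"  # a hyperparameter, saved alongside the parameters


def to_json(obj: Any) -> Any:
    """Convert a module tree into JSON-compatible data, recording class names."""
    if is_dataclass(obj):
        return {"__class__": type(obj).__name__,
                "fields": {f.name: to_json(getattr(obj, f.name)) for f in fields(obj)}}
    if isinstance(obj, list):
        return [to_json(x) for x in obj]
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj
    raise TypeError(f"refusing to serialize {type(obj).__name__}")


def from_json(data: Any) -> Any:
    """Rebuild a module tree; unknown class names raise KeyError."""
    if isinstance(data, dict):
        cls = REGISTRY[data["__class__"]]
        return cls(**{k: from_json(v) for k, v in data["fields"].items()})
    if isinstance(data, list):
        return [from_json(x) for x in data]
    return data


layer = Linear(w=[[1.0, 2.0]], b=[0.0], activation="tanh")
text = json.dumps(to_json(layer))
restored = from_json(json.loads(text))
print(restored == layer)  # True
```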
FlareJax includes wrappers around the Jax function transformations, which return callable modules.
```python
model = fj.VMap(model)
model = fj.Jit(model)
```
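When a transformation wrapper is itself a module, the wrapped model stays an ordinary field that can be inspected, nested or replaced. The plain-Python sketch below shows that design with a hypothetical `BatchMap` class that mimics only the batching semantics of `jax.vmap` over a Python list; it does not call Jax and is not FlareJax's `VMap`.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable


@dataclass(frozen=True)
class Scale:
    """A toy module: multiplies its input by a fixed factor."""
    factor: float

    def __call__(self, x: float) -> float:
        return self.factor * x


@dataclass(frozen=True)
class BatchMap:
    """Hypothetical stand-in for a VMap-style wrapper: maps `module` over a list."""
    module: Callable[[Any], Any]

    def __call__(self, xs: list) -> list:
        return [self.module(x) for x in xs]


model = BatchMap(Scale(2.0))
print(model([1.0, 2.0, 3.0]))  # [2.0, 4.0, 6.0]

# The wrapper is itself a module, so wrappers nest ...
nested = BatchMap(BatchMap(Scale(2.0)))
print(nested([[1.0], [2.0, 3.0]]))  # [[2.0], [4.0, 6.0]]

# ... and the wrapped module is an ordinary field that can be swapped out.
tripled = replace(model, module=Scale(3.0))
print(tripled([1.0]))  # [3.0]
```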
Roadmap
- Filtered transformations based on key paths
See also
- The beautiful Equinox library