Immutable pytree module classes with easy manipulation and serialization
FlareJax
Simple pytree module classes for Jax, strongly inspired by Equinox
- Referential transparency via strict immutability
- Safe serialization including hyperparameters
- Bound methods and function transformations are also modules
- Auxiliary information in key paths for filtered transformations
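Key paths are JAX's standard way of naming every leaf in a pytree. The following is a minimal plain-JAX sketch, not Flarejax's implementation: the hypothetical `Pair` class is registered with `jax.tree_util.register_pytree_with_keys` so that its leaves appear under attribute-style path entries (`GetAttrKey`), and a simple filter then selects leaves by path. What extra information Flarejax itself stores in its key paths is not described here, so treat this only as an illustration of the underlying mechanism.

```python
import jax.numpy as jnp
from jax.tree_util import (
    GetAttrKey,
    keystr,
    register_pytree_with_keys,
    tree_flatten_with_path,
)


class Pair:
    """Hypothetical container with two array fields, used only for this sketch."""

    def __init__(self, w, b):
        self.w = w
        self.b = b


def pair_flatten_with_keys(pair):
    # Each child is paired with a key entry; GetAttrKey renders as ".name"
    children = ((GetAttrKey("w"), pair.w), (GetAttrKey("b"), pair.b))
    return children, None  # no static auxiliary data


def pair_unflatten(aux_data, children):
    return Pair(*children)


register_pytree_with_keys(Pair, pair_flatten_with_keys, pair_unflatten)

model = {
    "encoder": Pair(jnp.ones((2, 2)), jnp.zeros(2)),
    "decoder": Pair(jnp.ones((2, 2)), jnp.zeros(2)),
}

path_leaves, _ = tree_flatten_with_path(model)
# Dict keys are flattened in sorted order, so "decoder" comes first
paths = [keystr(path) for path, _ in path_leaves]
# paths == ["['decoder'].w", "['decoder'].b", "['encoder'].w", "['encoder'].b"]

# A filter on key paths: select only the weight matrices
weights = [leaf for path, leaf in path_leaves if keystr(path).endswith(".w")]
```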
Quick Examples
Modules work similarly to dataclasses, but with the added benefit of being pytrees, which makes them compatible with all Jax function transformations.
```python
import jax
import flarejax as fj

class Linear(fj.Module):
    # The __init__ method is automatically generated
    w: jax.Array
    b: jax.Array

    # Additional initialization methods via classmethods
    @classmethod
    def init(cls, key, dim_in, dim):
        w = jax.random.normal(key, (dim, dim_in)) * 0.02
        b = jax.numpy.zeros((dim,))
        return cls(w=w, b=b)

    def __call__(self, x):
        return self.w @ x + self.b

key = jax.random.PRNGKey(42)
key1, key2 = jax.random.split(key)

# A 3 -> 2 -> 5 stack of linear layers
model = fj.Sequential(
    (
        Linear.init(key1, 3, 2),
        Linear.init(key2, 2, 5),
    )
)
```
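Being a pytree is what lets a module pass through `jax.grad`, `jax.jit` and the other transformations. As a hedged plain-JAX sketch of that idea (not how Flarejax implements `fj.Module`), the hypothetical `PlainLinear` below is a frozen dataclass registered with `jax.tree_util.register_pytree_node_class`. The gradient of a loss with respect to the model comes back as another `PlainLinear`, and an update step with `jax.tree_util.tree_map` builds a new model instead of mutating the old one.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical plain-JAX stand-in for an fj.Module subclass: a frozen
# dataclass registered as a pytree so transformations can see its arrays.
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class PlainLinear:
    w: jax.Array
    b: jax.Array

    def __call__(self, x):
        return self.w @ x + self.b

    def tree_flatten(self):
        # (children traced by JAX, static auxiliary data)
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(model, x, y):
    return jnp.sum((model(x) - y) ** 2)


model = PlainLinear(w=jnp.ones((2, 3)), b=jnp.zeros(2))
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.zeros(2)

# The gradient has the same structure (and class) as the model
grads = jax.jit(jax.grad(loss))(model, x, y)

# A plain gradient-descent step returns a new PlainLinear
new_model = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, model, grads)
```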
The model can be serialized and deserialized using fj.save and fj.load.
```python
fj.save("model.npz", model)
model = fj.load("model.npz")
```
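One way to picture pickle-free serialization is sketched below in plain JAX and NumPy. The helpers `save_tree` and `load_tree` are hypothetical and are not `fj.save`/`fj.load`: they write each leaf of a pytree to an `.npz` archive, store the key paths and a dict of JSON hyperparameters as a JSON string next to them, and need a template pytree to restore the structure, whereas `fj.load` in the example above rebuilds the model from the file alone.

```python
import json
import os
import tempfile

import jax.numpy as jnp
import numpy as np
from jax.tree_util import keystr, tree_flatten_with_path, tree_unflatten


def save_tree(filename, tree, hyperparams):
    """Hypothetical helper: store pytree leaves plus JSON metadata in an .npz file."""
    path_leaves, _ = tree_flatten_with_path(tree)
    arrays = {f"leaf_{i}": np.asarray(leaf) for i, (_, leaf) in enumerate(path_leaves)}
    meta = {"paths": [keystr(path) for path, _ in path_leaves], "hyperparams": hyperparams}
    # A 0-d unicode array is stored as plain data, so no pickling is involved
    arrays["__meta__"] = np.array(json.dumps(meta))
    np.savez(filename, **arrays)


def load_tree(filename, template):
    """Hypothetical helper: rebuild a pytree shaped like `template` from save_tree output."""
    with np.load(filename) as data:  # allow_pickle defaults to False
        meta = json.loads(data["__meta__"].item())
        path_leaves, treedef = tree_flatten_with_path(template)
        if [keystr(path) for path, _ in path_leaves] != meta["paths"]:
            raise ValueError("saved key paths do not match the template")
        leaves = [jnp.asarray(data[f"leaf_{i}"]) for i in range(len(path_leaves))]
    return tree_unflatten(treedef, leaves), meta["hyperparams"]


params = {"w": jnp.arange(6.0).reshape(2, 3), "b": jnp.zeros(2)}
with tempfile.TemporaryDirectory() as tmp:
    filename = os.path.join(tmp, "model.npz")
    save_tree(filename, params, {"dim_in": 3, "dim": 2})
    restored, hyperparams = load_tree(filename, params)
```

Checking the stored key paths against the template turns a structural mismatch into an explicit error instead of silently loading arrays into the wrong leaves.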
Flarejax includes wrappers around the Jax function transformations that return callable modules, so a transformed model is itself a module and can be wrapped again:
```python
model = fj.VMap(model)
model = fj.Jit(model)
```
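To show why a transformation wrapper can itself be a module, here is a hedged plain-JAX sketch. The hypothetical `VMapped` class is a frozen dataclass registered as a pytree whose only child is the wrapped model; calling it applies `jax.vmap` over the leading axis of the input. Because the wrapper is still a pytree, it can be passed straight into a jitted function. Flarejax's own `fj.VMap` and `fj.Jit` may differ in their details, for example in how axes are configured.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Affine:
    """Hypothetical callable pytree standing in for a module."""

    w: jax.Array
    b: jax.Array

    def __call__(self, x):
        return self.w @ x + self.b

    def tree_flatten(self):
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class VMapped:
    """Hypothetical wrapper: maps the inner callable over the leading input axis."""

    inner: Affine

    def __call__(self, xs):
        return jax.vmap(self.inner)(xs)

    def tree_flatten(self):
        # The wrapped module is a child, so its arrays stay visible to JAX
        return (self.inner,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


model = VMapped(Affine(w=jnp.eye(2), b=jnp.ones(2)))
# The wrapper is a pytree, so it can be an ordinary argument of a jitted function
apply = jax.jit(lambda module, xs: module(xs))
xs = jnp.arange(6.0).reshape(3, 2)
ys = apply(model, xs)  # each row x becomes x + 1
```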
Installation
Flarejax can be installed directly from PyPI using pip. It requires Python 3.10+ and Jax 0.4.26+.
```shell
pip install flarejax
```
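The version requirements can be checked in an existing environment with a small stdlib-only sketch. The function names are hypothetical, and the parser only compares the leading major.minor.patch numbers, ignoring suffixes such as dev or rc tags.

```python
import re
import sys


def version_tuple(version):
    """Parse the leading 'major.minor.patch' numbers of a version string."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def check_requirements(jax_version, python_version=sys.version_info[:3]):
    """Return a list of unmet requirements (empty if Python 3.10+ and Jax 0.4.26+)."""
    problems = []
    if tuple(python_version) < (3, 10):
        problems.append(f"Python 3.10+ required, found {python_version}")
    if version_tuple(jax_version) < (0, 4, 26):
        problems.append(f"Jax 0.4.26+ required, found {jax_version}")
    return problems
```

For example, calling `check_requirements(jax.__version__)` after `import jax` returns an empty list when both requirements are met. Comparing integer tuples rather than strings avoids ordering "0.4.9" above "0.4.26".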
Design
Flarejax modules sacrifice some flexibility for the sake of a unified interface and safety. Flarejax code should always be easy to reason about and should not contain any footguns from Python magic.
- Everything is immutable.
- Module fields can only be Jax arrays, other modules, or JSON-like data.
This makes it harder to use other Jax libraries inside Flarejax modules; the recommended approach is to wrap the needed functionality in a module. In the other direction, most Jax libraries should be compatible with Flarejax modules, since these are simply callable pytrees.
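The two design rules can be illustrated with standard-library dataclasses. The sketch below is hypothetical (`StrictModule`, `is_json_like` and `Dense` are not Flarejax names) and leaves out pytree registration: a frozen dataclass rejects attribute assignment, `dataclasses.replace` produces a modified copy, and `__post_init__` rejects fields that are neither Jax arrays, nested modules nor JSON-like values.

```python
import dataclasses

import jax
import jax.numpy as jnp


def is_json_like(value):
    """True for None, bools, numbers, strings and lists/tuples/str-keyed dicts of them."""
    if value is None or isinstance(value, (bool, int, float, str)):
        return True
    if isinstance(value, (list, tuple)):
        return all(is_json_like(item) for item in value)
    if isinstance(value, dict):
        return all(isinstance(k, str) and is_json_like(v) for k, v in value.items())
    return False


@dataclasses.dataclass(frozen=True)
class StrictModule:
    """Hypothetical base class: immutable, with restricted field types."""

    def __post_init__(self):
        for fld in dataclasses.fields(self):
            value = getattr(self, fld.name)
            allowed = isinstance(value, (jax.Array, StrictModule)) or is_json_like(value)
            if not allowed:
                raise TypeError(
                    f"field {fld.name!r} has unsupported type {type(value).__name__}"
                )


@dataclasses.dataclass(frozen=True)
class Dense(StrictModule):
    w: jax.Array
    activation: str = "relu"  # a JSON-like hyperparameter


layer = Dense(w=jnp.ones((2, 2)))

try:
    layer.activation = "tanh"  # frozen: raises FrozenInstanceError
except dataclasses.FrozenInstanceError:
    pass

# "Changing" a field means building a new module; the original is untouched
tanh_layer = dataclasses.replace(layer, activation="tanh")

try:
    Dense(w=jnp.ones((2, 2)), activation=print)  # a function is not allowed
except TypeError as err:
    print(err)
```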
Roadmap
- Filtered grad transformation based on key paths
- Pretty printing for modules
- Rule to infer static arguments in jitted functions, possibly everything except JAX arrays
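The first roadmap item is listed as future work, so the following is only a rough illustration of the idea, not a preview of the planned API. The hypothetical `filtered_grad` differentiates a loss with respect to the leaves whose key path (rendered with `jax.tree_util.keystr`) satisfies a predicate, and treats all other leaves as constants.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import keystr, tree_flatten_with_path, tree_unflatten


def filtered_grad(loss_fn, params, predicate):
    """Hypothetical: gradient of loss_fn w.r.t. leaves whose key path passes predicate.

    Leaves that fail the predicate are treated as constants and get zero gradients.
    """
    path_leaves, treedef = tree_flatten_with_path(params)
    leaves = [leaf for _, leaf in path_leaves]
    mask = [predicate(keystr(path)) for path, _ in path_leaves]
    trainable = [leaf for leaf, keep in zip(leaves, mask) if keep]

    def loss_of_trainable(trainable_leaves):
        # Re-insert the trainable leaves; frozen leaves are closed-over constants
        it = iter(trainable_leaves)
        merged = [next(it) if keep else leaf for leaf, keep in zip(leaves, mask)]
        return loss_fn(tree_unflatten(treedef, merged))

    trainable_grads = jax.grad(loss_of_trainable)(trainable)
    it = iter(trainable_grads)
    full = [next(it) if keep else jnp.zeros_like(leaf) for leaf, keep in zip(leaves, mask)]
    return tree_unflatten(treedef, full)


def loss(params):
    return jnp.sum(params["w"] ** 2) + params["b"] ** 2


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
# Only differentiate the leaf at key path "['w']"; "b" stays frozen
grads = filtered_grad(loss, params, lambda path: path == "['w']")
# grads["w"] == [2.0, 4.0], grads["b"] == 0.0
```

Returning zeros for frozen leaves, rather than dropping them, is a choice of this sketch: it keeps the gradient in the same structure as the parameters, so it can be combined with them leaf by leaf.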
See also
- The beautiful Equinox library
File details
Details for the file flarejax-0.3.9.tar.gz.
File metadata
- Download URL: flarejax-0.3.9.tar.gz
- Upload date:
- Size: 9.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | be941a5929ca07e9d6c7f0e3c53a7f489a2b3fa26e08ba454b255b893b844673
MD5 | 40e87c2f6e649bdfe0f719af1deff6ef
BLAKE2b-256 | 373708deb045e3abe02ee6e43cf8c5f3c1f522b9710edec9e329a0457b163a98
File details
Details for the file flarejax-0.3.9-py3-none-any.whl.
File metadata
- Download URL: flarejax-0.3.9-py3-none-any.whl
- Upload date:
- Size: 13.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | 67c9872cd4734e9d468dade2015d07326e7640539965f0fb05d72f88c854b24a
MD5 | d268a0c243b7585507292d1b158f3298
BLAKE2b-256 | f1a9e0fef586e213b0882100d331d93a9979bc4ce2223d104c562386b1682747