
Pytree module classes with easy manipulation and serialization


FlareJax

Simple pytree module classes for Jax, strongly inspired by Equinox

  • Referential transparency via strict immutability
  • Easy manipulation using .at & .set
  • Safe serialization including hyperparameters
  • Bound methods and function transformations are also modules
  • Auxiliary information in key paths for filtered transformations

Installation

FlareJax can be installed directly from PyPI using pip. It requires Python 3.10+ and Jax 0.4.26+.

pip install flarejax

Quick Examples

Modules work like dataclasses, with the added benefit of being pytrees, which makes them compatible with all Jax function transformations.

import jax
import flarejax as fj

class Linear(fj.Module):
    # The __init__ method is automatically generated
    w: jax.Array
    b: jax.Array

    # Non-pytree (static) fields are marked with leaf=False
    aux: None = fj.field(leaf=False, default=None)

    # Additional initialization methods can be defined as classmethods
    @classmethod
    def init(cls, key, dim_in, dim):
        w = jax.random.normal(key, (dim, dim_in))
        b = jax.numpy.zeros((dim,))
        return cls(w=w, b=b)

    def __call__(self, x):
        return self.w @ x + self.b

key = jax.random.PRNGKey(42)
key1, key2 = jax.random.split(key)

model = fj.Sequential(
    (
        Linear.init(key1, 3, 2),
        Linear.init(key2, 2, 5),
    )
)

Although modules are immutable, modified copies can be created using the at property.

w_new = jax.numpy.ones((2, 3))
model = model.at[0].w.set(w_new)

Turning train mode off for the first layer takes a single call to set.

model = model.at[0].config["train"].set(False)
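The README does not describe how the .at accessor is implemented. As a minimal sketch only, assuming modules behave like frozen dataclasses, a path-recording setter could look like this; the `Layer`, `At` and `_set` names are hypothetical and not part of FlareJax. Each `.set` rebuilds only the nodes along the path and shares every untouched subtree with the original.

```python
from dataclasses import dataclass, field, replace
from typing import Any


@dataclass(frozen=True)
class Layer:
    # Hypothetical stand-in for a module: frozen, so fields cannot be reassigned.
    w: Any
    b: Any
    config: dict = field(default_factory=dict)


def _set(node, path, value):
    """Return a copy of `node` with `value` placed at `path`, sharing untouched parts."""
    if not path:
        return value
    (kind, key), rest = path[0], path[1:]
    if kind == "attr":
        # Frozen dataclass: build a new instance with one field swapped.
        return replace(node, **{key: _set(getattr(node, key), rest, value)})
    if isinstance(node, tuple):
        items = list(node)
        items[key] = _set(node[key], rest, value)
        return tuple(items)
    if isinstance(node, dict):
        new = dict(node)  # copy instead of mutating the original mapping
        new[key] = _set(node[key], rest, value)
        return new
    raise TypeError(f"cannot index into {type(node).__name__}")


class At:
    """Records attribute names, indices and keys; .set() rebuilds the tree."""

    def __init__(self, root, path=()):
        self._root = root
        self._path = path

    def __getattr__(self, name):
        return At(self._root, self._path + (("attr", name),))

    def __getitem__(self, key):
        return At(self._root, self._path + (("item", key),))

    def set(self, value):
        return _set(self._root, self._path, value)


model = (
    Layer(w=((0.0,) * 3,) * 2, b=(0.0,) * 2, config={"train": True}),
    Layer(w=((0.0,) * 2,) * 5, b=(0.0,) * 5, config={"train": True}),
)
model2 = At(model)[0].w.set(((1.0,) * 3,) * 2)
model3 = At(model2)[0].config["train"].set(False)
print(model[0].config["train"], model3[0].config["train"])  # True False
```

Because nothing is mutated, the original `model` remains valid after both updates, which is what makes this style safe to combine with Jax transformations.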

The model can be serialized and deserialized using fj.save and fj.load.

fj.save("model.npz", model)
model = fj.load("model.npz")
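The on-disk layout that fj.save writes is not described here. To illustrate why a format that stores hyperparameters as data can be safer than pickle, here is a sketch that uses JSON and a hypothetical class registry (`REGISTRY`, `register`, `encode` and `decode` are assumptions, not FlareJax API): loading can only construct classes that were explicitly registered.

```python
import json
from dataclasses import dataclass, fields, is_dataclass
from typing import Any

# Hypothetical allow-list: only classes registered here can be rebuilt on load.
REGISTRY = {}


def register(cls):
    REGISTRY[cls.__name__] = cls
    return cls


@register
@dataclass(frozen=True)
class Layer:
    w: Any        # arrays are stood in for by plain lists of floats
    b: Any
    config: dict  # hyperparameters are stored alongside the weights


def encode(node):
    """Turn a module tree into JSON-compatible data tagged with type names."""
    if is_dataclass(node):
        return {"__type__": type(node).__name__,
                "fields": {f.name: encode(getattr(node, f.name)) for f in fields(node)}}
    if isinstance(node, tuple):
        return {"__tuple__": [encode(c) for c in node]}
    if isinstance(node, dict):
        # Assumes string keys, since JSON object keys must be strings.
        return {"__mapping__": {k: encode(v) for k, v in node.items()}}
    return node  # numbers, strings, bools, None and lists are stored as-is


def decode(obj):
    """Rebuild a tree; unknown type names raise KeyError instead of running code."""
    if isinstance(obj, dict):
        if "__type__" in obj:
            cls = REGISTRY[obj["__type__"]]
            return cls(**{k: decode(v) for k, v in obj["fields"].items()})
        if "__tuple__" in obj:
            return tuple(decode(c) for c in obj["__tuple__"])
        if "__mapping__" in obj:
            return {k: decode(v) for k, v in obj["__mapping__"].items()}
    return obj


model = (Layer(w=[[1.0, 2.0]], b=[0.0], config={"train": True}),)
text = json.dumps(encode(model))
restored = decode(json.loads(text))
print(restored == model)  # True
```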

FlareJax includes wrappers around the Jax function transformations; these wrappers return callable modules.

model = fj.VMap(model)
model = fj.Jit(model)
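A transformation that returns a module keeps the wrapped model as an ordinary field, so the result can still be inspected, updated and serialized like any other module. The following is a minimal sketch of that idea with frozen dataclasses; `Scale` and `Batched` are hypothetical names, and the Python loop merely stands in for jax.vmap, which vectorizes instead of looping.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable


@dataclass(frozen=True)
class Scale:
    # Hypothetical toy module: multiplies its input by a stored factor.
    factor: float

    def __call__(self, x):
        return self.factor * x


@dataclass(frozen=True)
class Batched:
    # Hypothetical wrapper module: the wrapped module is an ordinary field, so it
    # stays part of the tree. The Python loop only stands in for jax.vmap.
    module: Callable[[Any], Any]

    def __call__(self, xs):
        return [self.module(x) for x in xs]


model = Batched(Scale(2.0))
print(model([1.0, 2.0, 3.0]))  # [2.0, 4.0, 6.0]

# The inner module can still be updated like any other field.
model = replace(model, module=replace(model.module, factor=3.0))
print(model([1.0]))  # [3.0]
```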

Roadmap

  • Filtered transformations based on key paths
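Filtered transformations are only listed as a roadmap item, so their eventual API is unknown. One way selecting leaves by key path could work is sketched below; `Layer`, `leaves_with_paths` and `partition` are hypothetical helpers, not FlareJax API. For registered pytrees, JAX itself provides `jax.tree_util.tree_flatten_with_path`, which yields comparable (key path, leaf) pairs.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Layer:
    w: Any
    b: Any


def leaves_with_paths(node, prefix=()):
    """Yield (path, leaf) pairs; a path is a tuple of field names and indices."""
    if is_dataclass(node):
        for f in fields(node):
            yield from leaves_with_paths(getattr(node, f.name), prefix + (f.name,))
    elif isinstance(node, tuple):
        for i, child in enumerate(node):
            yield from leaves_with_paths(child, prefix + (i,))
    else:
        yield prefix, node


def partition(tree, keep: Callable[[tuple], bool]):
    """Split leaves into those whose key path satisfies `keep` and the rest."""
    selected, rest = {}, {}
    for path, leaf in leaves_with_paths(tree):
        (selected if keep(path) else rest)[path] = leaf
    return selected, rest


model = (Layer(w=[1.0, 2.0], b=[0.0]), Layer(w=[3.0], b=[0.5]))
# Select only the weight matrices, e.g. to differentiate with respect to them.
weights, others = partition(model, lambda path: path[-1] == "w")
print(sorted(weights))  # [(0, 'w'), (1, 'w')]
```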




Download files


Source Distribution

flarejax-0.2.3.tar.gz (10.7 kB)

Built Distribution

flarejax-0.2.3-py3-none-any.whl (13.8 kB, Python 3)
