
Immutable pytree module classes with easy manipulation and serialization


FlareJax

Simple pytree module classes for Jax, strongly inspired by Equinox

  • Referential transparency via strict immutability
  • Safe serialization including hyperparameters
  • Bound methods and function transformations are also modules
  • Auxiliary information in key paths for filtered transformations
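Key paths are what make filtered transformations possible: every leaf of a pytree can be addressed by the path of keys leading to it. The sketch below is an illustration built only on standard `jax.tree_util` calls and a plain nested dict, not FlareJax's API; `masked_grad` is a hypothetical helper. It naively computes the full gradient and then zeroes every leaf whose path does not pass a predicate, so it does not save any computation the way a true filtered transformation could.

```python
import jax
import jax.numpy as jnp


def masked_grad(loss_fn, keep):
    """Hypothetical helper: gradient of loss_fn w.r.t. its first argument,
    with every leaf whose key path fails `keep(path)` replaced by zeros."""
    grad_fn = jax.grad(loss_fn)

    def wrapped(params, *args):
        grads = grad_fn(params, *args)
        # tree_map_with_path passes (key path, leaf) for every leaf
        return jax.tree_util.tree_map_with_path(
            lambda path, g: g if keep(path) else jnp.zeros_like(g), grads
        )

    return wrapped


def loss(params, x):
    layer = params["linear"]
    return jnp.sum(layer["w"] @ x + layer["b"])


params = {"linear": {"w": jnp.ones((2, 3)), "b": jnp.zeros(2)}}

# Keep only leaves whose last key is "w"; DictKey stores the dict key in `.key`
only_weights = lambda path: getattr(path[-1], "key", None) == "w"
grads = masked_grad(loss, only_weights)(params, jnp.ones(3))

# Print each leaf's key path, e.g. "['linear']['w']", with its shape
for path, leaf in jax.tree_util.tree_flatten_with_path(grads)[0]:
    print(jax.tree_util.keystr(path), leaf.shape)
```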

Quick Examples

Modules work similarly to dataclasses, with the added benefit of being pytrees, which makes them compatible with all Jax function transformations.

import jax
import flarejax as fj

class Linear(fj.Module):
    # The __init__ method is automatically generated
    w: jax.Array
    b: jax.Array

    # Additional initialization methods are defined as classmethods
    @classmethod
    def init(cls, key, dim_in, dim):
        w = jax.random.normal(key, (dim, dim_in)) * 0.02
        b = jax.numpy.zeros((dim,))
        return cls(w=w, b=b)

    def __call__(self, x):
        return self.w @ x + self.b

key = jax.random.PRNGKey(42)
key1, key2 = jax.random.split(key)

model = fj.Sequential(
    (
        Linear.init(key1, 3, 2),
        Linear.init(key2, 2, 5),
    )
)
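To show why being a pytree matters, here is a minimal stand-in for the same model written with plain JAX. It assumes nothing about FlareJax internals: `pytree_dataclass` is a hypothetical helper that turns a class into a frozen dataclass and registers its fields as pytree children, and `Sequential` here simply applies its layers in order, which is an assumption about what `fj.Sequential` does. Because the model is a pytree, it can be passed straight into `jax.jit` and `jax.grad`.

```python
import dataclasses

import jax
import jax.numpy as jnp


def pytree_dataclass(cls):
    """Hypothetical helper (not FlareJax's API): make `cls` a frozen dataclass
    and register it as a pytree whose children are its fields."""
    cls = dataclasses.dataclass(frozen=True)(cls)
    names = tuple(f.name for f in dataclasses.fields(cls))

    def flatten_with_keys(obj):
        # Attribute names become key-path entries, e.g. ".layers[0].w"
        return [(jax.tree_util.GetAttrKey(n), getattr(obj, n)) for n in names], None

    def unflatten(_aux, children):
        return cls(*children)

    jax.tree_util.register_pytree_with_keys(cls, flatten_with_keys, unflatten)
    return cls


@pytree_dataclass
class Linear:
    w: jax.Array
    b: jax.Array

    @classmethod
    def init(cls, key, dim_in, dim):
        w = jax.random.normal(key, (dim, dim_in)) * 0.02
        b = jnp.zeros((dim,))
        return cls(w=w, b=b)

    def __call__(self, x):
        return self.w @ x + self.b


@pytree_dataclass
class Sequential:
    layers: tuple

    def __call__(self, x):
        # Feed the output of each layer into the next one
        for layer in self.layers:
            x = layer(x)
        return x


key1, key2 = jax.random.split(jax.random.PRNGKey(42))
model = Sequential((Linear.init(key1, 3, 2), Linear.init(key2, 2, 5)))
x = jnp.ones(3)

# The model is passed to jit as an ordinary pytree argument
y = jax.jit(lambda m, v: m(v))(model, x)

# Gradients come back with exactly the model's structure
grads = jax.grad(lambda m: m(x).sum())(model)
```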

The model can be serialized and deserialized using fj.save and fj.load.

fj.save("model.npz", model)
model = fj.load("model.npz")
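The on-disk format of `fj.save` is not documented here, so the following is only a sketch of one way "safe serialization including hyperparameters" can work with an `.npz` file: arrays are stored under flat names, and hyperparameters are stored as JSON text in a 0-d string array, so loading never needs pickle. The `save`/`load` functions, the `__hyperparams__` key and the dotted array names are all hypothetical.

```python
import json
import os
import tempfile

import numpy as np

HYPER_KEY = "__hyperparams__"  # hypothetical reserved entry name


def save(path, arrays, hyperparams):
    """Store named arrays plus JSON-encoded hyperparameters in one .npz file."""
    # A 0-d unicode array holds the JSON text, so no pickling is involved
    np.savez(path, **{HYPER_KEY: np.array(json.dumps(hyperparams))}, **arrays)


def load(path):
    # allow_pickle=False refuses object arrays, so loading cannot run pickled code
    with np.load(path, allow_pickle=False) as data:
        hyperparams = json.loads(data[HYPER_KEY].item())
        arrays = {k: data[k] for k in data.files if k != HYPER_KEY}
    return arrays, hyperparams


arrays = {"layers.0.w": np.ones((2, 3)), "layers.0.b": np.zeros(2)}
hyper = {"dim_in": 3, "dim": 2, "activation": "relu"}

with tempfile.TemporaryDirectory() as tmp:
    file = os.path.join(tmp, "model.npz")
    save(file, arrays, hyper)
    loaded_arrays, loaded_hyper = load(file)
```

Keeping hyperparameters as JSON is what lets a loader rebuild the module structure without executing arbitrary code.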

Flarejax includes wrappers around the Jax function transformations; each wrapper returns a callable module.

model = fj.VMap(model)
model = fj.Jit(model)
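The idea behind transformation wrappers that are themselves modules can be sketched with plain JAX. The `Affine`, `VMap` and `Jit` classes below are hypothetical illustrations, not FlareJax's implementation: each wrapper is a pytree whose only child is the wrapped module, so wrappers nest freely and the whole stack still flattens to the same parameter leaves.

```python
import jax
import jax.numpy as jnp


def _apply(module, x):
    return module(x)


# Compiled once at module level; the wrapped module is passed as a pytree argument
_jit_apply = jax.jit(_apply)


@jax.tree_util.register_pytree_node_class
class Affine:
    def __init__(self, w, b):
        self.w, self.b = w, b

    def __call__(self, x):
        return self.w @ x + self.b

    def tree_flatten(self):
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


@jax.tree_util.register_pytree_node_class
class VMap:
    """Hypothetical wrapper: maps the inner module over a leading batch axis."""

    def __init__(self, module):
        self.module = module

    def __call__(self, x):
        # Parameters are shared (in_axes=None); only x is batched
        return jax.vmap(_apply, in_axes=(None, 0))(self.module, x)

    def tree_flatten(self):
        return (self.module,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


@jax.tree_util.register_pytree_node_class
class Jit:
    """Hypothetical wrapper: runs the inner module through jax.jit."""

    def __init__(self, module):
        self.module = module

    def __call__(self, x):
        return _jit_apply(self.module, x)

    def tree_flatten(self):
        return (self.module,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


w = jnp.arange(6.0).reshape(2, 3)
b = jnp.ones(2)
model = Jit(VMap(Affine(w, b)))
xs = jnp.ones((4, 3))  # a batch of 4 inputs
ys = model(xs)
```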

Installation

Flarejax can be installed directly from PyPI using pip. It requires Python 3.10+ and Jax 0.4.26+.

pip install flarejax

Design

Flarejax modules sacrifice some flexibility for the sake of a unified interface and safety. Flarejax code should always be easy to reason about and should not contain any footguns from Python magic. Two rules follow from this:

  1. Everything is immutable.
  2. Module fields can be either Jax arrays, other modules, or JSON-like data.

These rules make it harder to use other Jax libraries inside Flarejax modules; the recommended approach is to wrap the needed functionality in a module. Most Jax libraries should nonetheless be compatible with Flarejax modules, since modules are simply callable pytrees.
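How Flarejax enforces these rules is not described here; the sketch below shows one plausible way using frozen dataclasses. `Module`, `is_json_like` and `Activation` are hypothetical names. Field values are checked at construction time, and assignment after construction fails because the dataclass is frozen. `Activation` also illustrates wrapping another library's functionality: it stores the *name* of a `jax.nn` function, which is JSON-like, rather than the function object, which is not.

```python
import dataclasses

import jax
import jax.numpy as jnp

JSON_SCALARS = (str, int, float, bool, type(None))


def is_json_like(value):
    """True for JSON-style data: scalars, lists/tuples and str-keyed dicts."""
    if isinstance(value, JSON_SCALARS):
        return True
    if isinstance(value, (list, tuple)):
        return all(is_json_like(v) for v in value)
    if isinstance(value, dict):
        return all(isinstance(k, str) and is_json_like(v) for k, v in value.items())
    return False


@dataclasses.dataclass(frozen=True)
class Module:
    """Hypothetical base class enforcing both design rules at construction."""

    def __post_init__(self):
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if not (isinstance(value, (jax.Array, Module)) or is_json_like(value)):
                raise TypeError(
                    f"field {field.name!r} has unsupported type {type(value).__name__}"
                )


@dataclasses.dataclass(frozen=True)
class Activation(Module):
    # Store the name of the jax.nn function (JSON-like), not the function itself
    name: str = "relu"

    def __call__(self, x):
        return getattr(jax.nn, self.name)(x)


act = Activation("relu")
out = act(jnp.array([-1.0, 2.0]))
```

Passing a function object, for example `Activation(jax.nn.relu)`, raises `TypeError`, and `act.name = "gelu"` raises `dataclasses.FrozenInstanceError`.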

Roadmap

  • Filtered grad transformation based on key paths
  • Pretty printing for modules
  • Rule to infer static arguments in jitted functions, possibly everything except JAX arrays
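The last roadmap item can be illustrated with plain `jax.jit`. This is a sketch of one possible rule, not FlareJax code: `jit_arrays_only` is a hypothetical decorator that traces every `jax.Array` leaf of the arguments and treats every other leaf as a static (hashable) value. The trade-off is that each new Python value, such as a different float, triggers recompilation, and unhashable non-array leaves such as NumPy arrays would be rejected.

```python
import functools

import jax
import jax.numpy as jnp

_DYNAMIC = object()  # placeholder marking where a traced array goes


@functools.partial(jax.jit, static_argnums=(0, 1, 2))
def _call(fn, treedef, static, dynamic):
    # Re-interleave traced arrays with the static (hashable) leaves
    leaves = [d if s is _DYNAMIC else s for s, d in zip(static, dynamic)]
    return fn(*jax.tree_util.tree_unflatten(treedef, leaves))


def jit_arrays_only(fn):
    """Hypothetical rule: every jax.Array leaf is traced, everything else is static."""

    @functools.wraps(fn)
    def wrapped(*args):
        leaves, treedef = jax.tree_util.tree_flatten(args)
        is_array = [isinstance(leaf, jax.Array) for leaf in leaves]
        static = tuple(_DYNAMIC if a else leaf for a, leaf in zip(is_array, leaves))
        dynamic = [leaf if a else None for a, leaf in zip(is_array, leaves)]
        return _call(fn, treedef, static, dynamic)

    return wrapped


def scale(x, factor, mode):
    # `mode` drives Python control flow, so it has to be static
    return x * factor if mode == "mul" else x + factor


f = jit_arrays_only(scale)
a = f(jnp.ones(3), 2.0, "mul")
b = f(jnp.ones(3), 2.0, "add")
```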



Download files


Source Distribution

flarejax-0.3.8.tar.gz (9.7 kB)

Built Distribution

flarejax-0.3.8-py3-none-any.whl (13.0 kB)

File details

Details for the file flarejax-0.3.8.tar.gz.

File metadata

  • Download URL: flarejax-0.3.8.tar.gz
  • Upload date:
  • Size: 9.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.9

File hashes

Hashes for flarejax-0.3.8.tar.gz
Algorithm Hash digest
SHA256 c4a6c1e0037b8dc0a8b4df7c3ed5306d519c528a36792ec9769d77965bcfa76b
MD5 2e89c18c49b3357b0dc00d37f06d3f3c
BLAKE2b-256 cf8eb807e8a7793b6d168b107a3e82806b310c6638f5a158775a986d7e868e06


File details

Details for the file flarejax-0.3.8-py3-none-any.whl.

File metadata

  • Download URL: flarejax-0.3.8-py3-none-any.whl
  • Upload date:
  • Size: 13.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.9

File hashes

Hashes for flarejax-0.3.8-py3-none-any.whl
Algorithm Hash digest
SHA256 fb1a612c0f5d31b64e1e29a5a3e3f36308f45cf2b903170d6e07c7b24ee0a475
MD5 bfdd64a8595a136fa291a1ae9b976472
BLAKE2b-256 fe9d6ed0a76692bbaf9f1b722fdbc68f102f7d32da18ce97cc7f3b2d712bf087
