
Flexible Neural Networks in JAX


FlareJax

FlareJax is a Python library for building neural networks and optimizers in JAX. It is designed to minimize the time between a new research idea and its implementation.

Features

  • Mutable modules for quick and dirty modifications via Module
  • Serialization of modules via @saveable, save, and load
  • Systematically modifying modules by using flatten and unflatten
  • Safely handling shared/cyclical references and static arguments through filter_jit
  • Commonly used NN layers and optimizers are included
  • As a small codebase, it is relatively easy to understand and extend
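The flatten/unflatten and shared-reference features can be pictured with a small, stdlib-only sketch. It is not FlareJax's implementation: `Node`, `flatten`, and `unflatten` below are hypothetical stand-ins. Each shared or cyclical object is recorded once, under the first path where it is found, so writing modified leaves back keeps shared weights shared and never loops on a cycle.

```python
from typing import Any


class Node:
    """Hypothetical stand-in for a mutable module: attributes live in __dict__."""

    def __init__(self, **attrs: Any) -> None:
        self.__dict__.update(attrs)


def flatten(root: Node) -> tuple[dict[str, Any], dict[str, str]]:
    """Return (leaves by dotted path, aliases for repeated nodes)."""
    leaves: dict[str, Any] = {}
    aliases: dict[str, str] = {}
    seen: dict[int, str] = {}  # id(node) -> first path where it was visited

    def visit(obj: Any, path: str) -> None:
        if isinstance(obj, Node):
            if id(obj) in seen:
                # shared or cyclical reference: point at the first path, don't recurse
                aliases[path] = seen[id(obj)]
                return
            seen[id(obj)] = path
            for name, value in obj.__dict__.items():
                visit(value, f"{path}.{name}" if path else name)
        else:
            leaves[path] = obj

    visit(root, "")
    return leaves, aliases


def unflatten(root: Node, leaves: dict[str, Any]) -> None:
    """Write (possibly modified) leaves back into the same graph, in place."""
    for path, value in leaves.items():
        *parents, attr = path.split(".")
        obj = root
        for name in parents:
            obj = getattr(obj, name)
        setattr(obj, attr, value)


# usage: two submodules share one weight, plus a back-reference cycle
shared = Node(w=1.0)
model = Node(enc=shared, dec=shared, scale=2.0)
model.parent = model
leaves, aliases = flatten(model)
unflatten(model, {path: value * 10 for path, value in leaves.items()})
```

Because `dec` is stored only as an alias of `enc`, scaling the leaves updates the single shared weight once, and both submodules observe the new value.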

Quick Example

Define new Modules by subclassing Module. All methods are callable PyTrees.

import flarejax as flr
import jax.numpy as jnp
import jax.random as jrn
from jaxtyping import PRNGKeyArray

@flr.saveable("Example:Linear")  # optional, makes the class saveable
class Linear(flr.Module):
    def __init__(self, key: PRNGKeyArray, dim: int):
        self.key = key  # kept for the deferred initialization in __call__
        self.dim = dim
        self.w = None

    def __call__(self, x):
        # lazy initialization dependent on the input shape
        if self.w is None:
            self.w = jrn.normal(self.key, (x.shape[-1], self.dim))

        return x @ self.w

layer = Linear(jrn.PRNGKey(0), 3)
x = jnp.zeros((1, 4))

# the weights are created during the first call
y = layer(x)
assert layer.w.shape == (4, 3)

For optimization, define a loss function that takes the module as its first argument.

def loss_fn(module, x, y):
    return jnp.mean((module(x) - y) ** 2)

opt = flr.opt.Adam(3e-4)

# flr.train is automatically just-in-time compiled
opt, layer, loss = flr.train(opt, layer, loss_fn, x, y)
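flr.opt.Adam(3e-4) selects the Adam optimizer, presumably with a learning rate of 3e-4. For reference, the standard Adam update can be sketched in plain Python on a one-weight linear model trained with the same mean-squared-error loss. This is not FlareJax's optimizer code: `mse_grad` and `adam_fit` are hypothetical helpers, the gradient is derived by hand instead of with jax.grad, and b1=0.9, b2=0.999, eps=1e-8 are the commonly used defaults, assumed here.

```python
import math


def mse_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """d/dw of mean((w * x - y) ** 2) for a one-weight linear model."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def adam_fit(
    xs: list[float],
    ys: list[float],
    lr: float = 1e-2,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    steps: int = 2000,
) -> float:
    w, m, v = 0.0, 0.0, 0.0  # weight, first moment, second moment
    for t in range(1, steps + 1):
        g = mse_grad(w, xs, ys)
        m = b1 * m + (1 - b1) * g      # running mean of gradients
        v = b2 * v + (1 - b2) * g * g  # running mean of squared gradients
        m_hat = m / (1 - b1 ** t)      # bias correction for the zero init
        v_hat = v / (1 - b2 ** t)
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w


# data generated by y = 2x, so the fitted weight should approach 2
w = adam_fit([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
```

The bias-corrected moments compensate for initializing m and v at zero, which would otherwise make the first few steps too small.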

Models can be saved and loaded.

flr.save(layer, "model.npz")

# load the model
layer = flr.load("model.npz")
assert isinstance(layer, Linear)
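The string passed to @flr.saveable appears to act as a stable type tag that lets flr.load rebuild the right class. A minimal sketch of that registry pattern, assuming JSON instead of the .npz format FlareJax actually writes; `saveable`, `dumps`, `loads`, and `Point` below are hypothetical and unrelated to the library's internals.

```python
import json
from typing import Any, Callable

REGISTRY: dict[str, type] = {}  # hypothetical: registered name -> class


def saveable(name: str) -> Callable[[type], type]:
    """Register a class under a stable string name."""

    def register(cls: type) -> type:
        if name in REGISTRY:
            raise ValueError(f"duplicate saveable name: {name}")
        REGISTRY[name] = cls
        cls._saveable_name = name
        return cls

    return register


def dumps(obj: Any) -> str:
    # store the type tag next to the instance attributes
    return json.dumps({"type": obj._saveable_name, "attrs": vars(obj)})


def loads(text: str) -> Any:
    data = json.loads(text)
    cls = REGISTRY[data["type"]]
    obj = cls.__new__(cls)  # skip __init__ and restore attributes directly
    obj.__dict__.update(data["attrs"])
    return obj


@saveable("Example:Point")
class Point:
    def __init__(self, x: float, y: float) -> None:
        self.x, self.y = x, y


p = loads(dumps(Point(1.0, 2.0)))
```

Keying on a string rather than on the class's import path means, in this sketch, that a class can be renamed or moved without breaking previously saved files, as long as its registered name stays the same.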

Installation

FlareJax can be installed via pip. It requires Python 3.10 or higher and JAX 0.4.33 or higher.

pip install flarejax
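To confirm the stated minimums (Python 3.10, JAX 0.4.33) in an existing environment, a small stdlib sketch can help. The helper names are hypothetical, and the version parsing is deliberately simple: it keeps only leading numeric components, so a pre-release such as 0.4.33rc1 is conservatively treated as older than 0.4.33.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def parse(v: str) -> tuple[int, ...]:
    """Keep leading numeric components, e.g. "0.4.35.dev1" -> (0, 4, 35)."""
    parts = []
    for piece in v.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def check_requirements() -> list[str]:
    """Return a list of unmet requirements (empty if everything is fine)."""
    problems = []
    if sys.version_info < (3, 10):
        problems.append(f"Python {sys.version.split()[0]} < 3.10")
    try:
        jax_version = version("jax")
        if parse(jax_version) < (0, 4, 33):
            problems.append(f"jax {jax_version} < 0.4.33")
    except PackageNotFoundError:
        problems.append("jax is not installed")
    return problems


if __name__ == "__main__":
    for problem in check_requirements():
        print("unmet requirement:", problem)
```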

See Also

  • JAX Docs: JAX is a library for numerical computing that is designed to be composable and fast.
  • Equinox library: FlareJax is heavily inspired by this awesome library.
  • torch.nn.Module: Many of the principles of mutability are inspired by PyTorch's torch.nn.Module.
  • NNX Docs: NNX is a neural network library within Flax that also supports mutability.
  • Always feel free to reach out to me via email.

Download files

Source Distribution

flarejax-0.4.8.tar.gz (19.9 kB)

Built Distribution

flarejax-0.4.8-py3-none-any.whl (23.9 kB)

File details

Details for the file flarejax-0.4.8.tar.gz.

File metadata

  • Download URL: flarejax-0.4.8.tar.gz
  • Size: 19.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.3

File hashes

Hashes for flarejax-0.4.8.tar.gz
Algorithm    Hash digest
SHA256       49208a24904b508cbcd0dddcedd1057f06c89d8b8e375ba445e1abe2752e7fb5
MD5          0a4badb3d9f6c7e82b83c822bc7ae83b
BLAKE2b-256  5aaabac31c6a5ff7aeff17ef964607cead2324ff4c1e615861e3753e9ce8e89d
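To check a downloaded archive against the SHA256 digest listed for flarejax-0.4.8.tar.gz, a short sketch using only Python's hashlib is enough. The helper names `sha256_of` and `verify` are hypothetical; the expected digest is copied from the table for the source distribution.

```python
import hashlib
from pathlib import Path

# SHA256 of flarejax-0.4.8.tar.gz, as listed in the hash table
EXPECTED_SHA256 = "49208a24904b508cbcd0dddcedd1057f06c89d8b8e375ba445e1abe2752e7fb5"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hash a file in chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 matches the expected hex digest."""
    return sha256_of(path) == expected.lower()
```

For example, verify(Path("flarejax-0.4.8.tar.gz")) should return True for an untampered download placed in the current directory.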


File details

Details for the file flarejax-0.4.8-py3-none-any.whl.

File metadata

  • Download URL: flarejax-0.4.8-py3-none-any.whl
  • Size: 23.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.3

File hashes

Hashes for flarejax-0.4.8-py3-none-any.whl
Algorithm    Hash digest
SHA256       44460f76cf9a26d737fc7dd6b70ad98ab2f96e0d0baed9f2d862115263154272
MD5          af067abfadb2e7140fb71c014b771e6a
BLAKE2b-256  bbd4d1d49545d12a78ed22d6bd7739145c25b34fcf695e47cd69033af761f6fe

