
REAX: A simple training framework for JAX-based projects

REAX is based on PyTorch Lightning and aims to bring a similar level of ease of use and customizability to training JAX models. Much of Lightning's API has been adopted, with some modifications to accommodate JAX's pure-function-based approach.

Quick start

pip install reax

REAX example

Define the training workflow. Here’s a toy example:

# main.py
# ! pip install torchvision
from functools import partial
import jax, optax, reax, flax.linen as linen
import torch.utils.data as data, torchvision as tv


class Autoencoder(linen.Module):
    def setup(self):
        # Submodules are declared in setup(); linen modules are dataclasses, so no __init__ call is needed
        self.encoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(3)])
        self.decoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(28 * 28)])

    def __call__(self, x):
        z = self.encoder(x)
        return self.decoder(z)


# --------------------------------
# Step 1: Define a REAX Module
# --------------------------------
# A reax.Module subclass defines a full *system*
# (e.g. an LLM, diffusion model, autoencoder, or simple image classifier).
class ReaxAutoEncoder(reax.Module):
    def __init__(self):
        super().__init__()
        self.ae = Autoencoder()

    def setup(self, stage: "reax.Stage", batch) -> None:
        if self.parameters() is None:
            x = batch[0].reshape(len(batch[0]), -1)
            params = self.ae.init(self.rng_key(), x)
            self.set_parameters(params)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

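    def forward(self, x):
        # Run only the encoder half: pass the full autoencoder's variables
        # and select the encoder submodule through `method`
        embedding = self.ae.apply(self.parameters(), x, method=lambda module, x: module.encoder(x))
        return embedding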

    def training_step(self, batch, batch_idx):
        x = batch[0].reshape(len(batch[0]), -1)
        loss, grads = jax.value_and_grad(self.loss_fn, argnums=0)(self.parameters(), x, self.ae)
        self.log("train_loss", loss, on_step=True, prog_bar=True)
        return loss, grads

    @staticmethod
    @partial(jax.jit, static_argnums=2)
    def loss_fn(params, x, model):
        predictions = model.apply(params, x)
        return optax.losses.squared_error(predictions, x).mean()

    def configure_optimizers(self):
        opt = optax.adam(learning_rate=1e-3)
        state = opt.init(self.parameters())
        return opt, state


# -------------------
# Step 2: Define data
# -------------------
dataset = tv.datasets.MNIST(".", download=True, transform=jax.numpy.asarray)
train, val = data.random_split(dataset, [55000, 5000])

# -------------------
# Step 3: Train
# -------------------
autoencoder = ReaxAutoEncoder()
trainer = reax.Trainer(autoencoder)
trainer.fit(reax.ReaxDataLoader(train), reax.ReaxDataLoader(val))

This example reproduces one from PyTorch Lightning, so it uses torchvision to fetch the data; real models do not need torchvision or PyTorch at all. Run the model from the terminal:

pip install reax torchvision
python main.py
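Since `training_step` only relies on `batch[0]` being an array of images, the data does not have to come from torchvision. As a minimal sketch, assuming the data is already held in NumPy arrays, a plain generator can produce shuffled `(images, labels)` batches of the same shape. The function `iterate_batches` and the stand-in data are hypothetical, and whether such an iterable can be passed directly to `trainer.fit` or must first be wrapped in a REAX data loader is not stated in this README, so check the REAX documentation.

```python
import numpy as np


def iterate_batches(images, labels, batch_size, seed=0):
    """Yield shuffled (images, labels) tuples; the last batch may be smaller."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(images))
    for start in range(0, len(images), batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]


# Hypothetical stand-in data with MNIST-like shapes: 1000 grayscale 28x28 images
images = np.random.default_rng(1).random((1000, 28, 28), dtype=np.float32)
labels = np.arange(1000) % 10

batches = list(iterate_batches(images, labels, batch_size=128))
# Same flattening as in training_step: (batch, 28, 28) -> (batch, 784)
first_x = batches[0][0].reshape(len(batches[0][0]), -1)
```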


Download files

Download the file for your platform.

Source Distribution

reax-0.5.0.tar.gz (110.5 kB)

Uploaded Source

Built Distribution

reax-0.5.0-py3-none-any.whl (143.4 kB)

Uploaded Python 3

File details

Details for the file reax-0.5.0.tar.gz.

File metadata

  • Download URL: reax-0.5.0.tar.gz
  • Upload date:
  • Size: 110.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-requests/2.32.3

File hashes

Hashes for reax-0.5.0.tar.gz
Algorithm Hash digest
SHA256 3d219c95fc1b6ccbae532b18aa61dbd60d449a3d7b9ff74320baf9a652f7de25
MD5 568e336f25d61184bba304cf8d6b54a9
BLAKE2b-256 c0483173d71009dc78029872c6dbc1b42be90f6aa53cdf7b26ad411eec841ebc
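To check a downloaded archive against the SHA256 digest listed above, Python's standard `hashlib` module is enough. This is a minimal sketch; the local file name passed to `verify` is an assumption about where the file was saved.

```python
import hashlib

EXPECTED_SHA256 = "3d219c95fc1b6ccbae532b18aa61dbd60d449a3d7b9ff74320baf9a652f7de25"


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected.lower()
```

For example, `verify("reax-0.5.0.tar.gz")` should return `True` for an intact download. pip can enforce the same check at install time with a requirements line such as `reax==0.5.0 --hash=sha256:<digest>` together with `--require-hashes`.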

File details

Details for the file reax-0.5.0-py3-none-any.whl.

File metadata

  • Download URL: reax-0.5.0-py3-none-any.whl
  • Upload date:
  • Size: 143.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-requests/2.32.3

File hashes

Hashes for reax-0.5.0-py3-none-any.whl
Algorithm Hash digest
SHA256 a442479774621913a6e01f8da7c70686b056a652afe9110f001b0c1fe196ed72
MD5 ad1ed6ee9ceebfdd5426f43e575cc94b
BLAKE2b-256 7892f92a1b8389eab34956bc0dbd7d3553aa3055372b15f836aee11456ca93e7