
REAX: A simple training framework for JAX-based projects

REAX is based on PyTorch Lightning and tries to bring a similar level of ease of use and customizability to the world of training JAX models. Much of Lightning's API has been adopted, with some modifications to accommodate JAX's approach based on pure functions.

Quick start

pip install reax

REAX example

Define the training workflow. Here’s a toy example:

# main.py
# ! pip install torchvision
from functools import partial
import jax, optax, reax, flax.linen as linen
import torch.utils.data as data, torchvision as tv


class Autoencoder(linen.Module):
    # Plain Flax model: flattened 28x28 image -> 128 -> 3-D embedding -> 128 -> 28x28
    def setup(self):
        self.encoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(3)])
        self.decoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(28 * 28)])

    def __call__(self, x):
        z = self.encoder(x)
        return self.decoder(z)


# --------------------------------
# Step 1: Define a reax.Module
# --------------------------------
# A reax.Module defines a full *system*
# (e.g. an LLM, diffusion model, autoencoder, or simple image classifier).


class ReaxAutoEncoder(reax.Module):
    def __init__(self):
        super().__init__()
        self.ae = Autoencoder()

    def setup(self, stage: "reax.Stage", batch) -> None:
        if self.parameters() is None:
            x = batch[0].reshape(len(batch[0]), -1)
            params = self.ae.init(self.rng_key(), x)
            self.set_parameters(params)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

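    def forward(self, x):
        # Run only the encoder half to get the 3-D embedding. Submodules defined in
        # Flax's setup() are only reachable inside apply(), so pass them via `method`.
        embedding = self.ae.apply(
            self.parameters(), x, method=lambda module, inputs: module.encoder(inputs)
        )
        return embedding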

    def training_step(self, batch, batch_idx):
        x = batch[0].reshape(len(batch[0]), -1)
        loss, grads = jax.value_and_grad(self.loss_fn, argnums=0)(self.parameters(), x, self.ae)
        self.log("train_loss", loss, on_step=True, prog_bar=True)
        return loss, grads

    @staticmethod
    @partial(jax.jit, static_argnums=2)
    def loss_fn(params, x, model):
        predictions = model.apply(params, x)
        return optax.losses.squared_error(predictions, x).mean()

    def configure_optimizers(self):
        opt = optax.adam(learning_rate=1e-3)
        state = opt.init(self.parameters())
        return opt, state


# -------------------
# Step 2: Define data
# -------------------
dataset = tv.datasets.MNIST(".", download=True, transform=jax.numpy.asarray)
train, val = data.random_split(dataset, [55000, 5000])

# -------------------
# Step 3: Train
# -------------------
autoencoder = ReaxAutoEncoder()
trainer = reax.Trainer(autoencoder)
trainer.fit(reax.ReaxDataLoader(train), reax.ReaxDataLoader(val))

The REAX example reproduces one from PyTorch Lightning, so it uses torchvision to fetch the data; real models do not need torchvision or PyTorch at all. Run the model from the terminal:

pip install reax torchvision
python main.py
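As a sketch of a PyTorch-free data source: the `ReaxAutoEncoder` in the example only indexes `batch[0]` and flattens it, so batches are tuples whose first element is a stack of images. The helper `numpy_batches` below is hypothetical, and whether `reax.ReaxDataLoader` or `trainer.fit` accept such a plain iterable is an assumption not confirmed here; the point is only the batch structure.

```python
import numpy as np


def numpy_batches(images, batch_size, rng):
    """Yield shuffled (x,) tuples, mirroring the batch[0] indexing in the example.

    Hypothetical helper: a plain NumPy stand-in for a PyTorch dataset and loader.
    """
    order = rng.permutation(len(images))
    for start in range(0, len(images), batch_size):
        idx = order[start:start + batch_size]
        yield (images[idx],)


rng = np.random.default_rng(0)
# Stand-in data with MNIST's shape: 1000 images of 28x28 pixels in [0, 1).
images = rng.random((1000, 28, 28), dtype=np.float32)

for batch in numpy_batches(images, batch_size=64, rng=rng):
    x = batch[0].reshape(len(batch[0]), -1)  # same flattening as training_step
```

The last batch is smaller (40 images here), which the example's `reshape(len(batch[0]), -1)` already handles.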

Download files

Source Distribution

reax-0.3.0.tar.gz (85.4 kB)

Uploaded Source

Built Distribution

reax-0.3.0-py3-none-any.whl (111.7 kB)

Uploaded Python 3

File details

Details for the file reax-0.3.0.tar.gz.

File metadata

  • Download URL: reax-0.3.0.tar.gz
  • Upload date:
  • Size: 85.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-requests/2.32.3

File hashes

Hashes for reax-0.3.0.tar.gz
Algorithm Hash digest
SHA256 f93c497480f4a05884980a3f4ae0e29564cc9120c5939c3b2697c26281cb6f88
MD5 3f34b202526e15734abca3af0ae05b5b
BLAKE2b-256 ab2276624fd05fe5b1bb7b68a3e4e8dcd0a6c353699db4306eb41869208c4b0e
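The published SHA256 digest can be checked against a downloaded archive before installing it. A minimal sketch with Python's standard `hashlib`; the local file name and location are assumptions (wherever the archive was saved):

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 listed above for reax-0.3.0.tar.gz.
EXPECTED = "f93c497480f4a05884980a3f4ae0e29564cc9120c5939c3b2697c26281cb6f88"

path = Path("reax-0.3.0.tar.gz")  # assumed download location
if path.exists():
    print("OK" if sha256_of(path) == EXPECTED else "MISMATCH: do not install")
```

pip can enforce the same check itself in hash-checking mode, using a requirements line of the form `reax==0.3.0 --hash=sha256:...` (one `--hash` per published file) and `pip install --require-hashes -r requirements.txt`.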

File details

Details for the file reax-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: reax-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 111.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-requests/2.32.3

File hashes

Hashes for reax-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 430b215d7f9f0fc9253405f1b955951e70d668bb325e9f4ec030f9a0a7b1b141
MD5 7e6637813bde494e17812c39bfb7aeed
BLAKE2b-256 cc47c4d38fdfc0f08416b0e1d036d1f21a02bdf9181f93eac106d8d380586ad4
