🥷 Ninjax: Flexible Modules for JAX

Ninjax is a general and practical module system for JAX. It gives users full and transparent control over updating the state of each module, bringing flexibility to JAX and enabling new use cases.

Overview

Ninjax provides a simple and general nj.Module class.

  • Modules can store state for things like model parameters, Adam momentum buffers, BatchNorm statistics, recurrent state, etc.
  • Modules can read and write their state entries. This allows, for example, train methods that update their module's parameters from the inside.
  • Any method can initialize, read, and write state entries. This avoids the need for a special setup() method or @compact decorator, as used in Flax.
  • Ninjax makes it easy to mix and match modules from different libraries, such as Flax and Haiku.
  • Instead of PyTrees, Ninjax state is a flat dict that maps string keys like /net/layer1/weights to jnp.array values. This makes the state easy to iterate over, modify, and save or load.
  • Modules can specify typed hyperparameters using the dataclass syntax.
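To illustrate the flat state layout, here is a minimal sketch using a plain Python dict (the key names are made up for this example; in Ninjax the values would be jnp.arrays):

```python
# Sketch of Ninjax's flat state layout with illustrative key names.
# Every entry maps a slash-separated path string to an array-like value,
# so ordinary dict operations work on the whole model state.
state = {
    'model/h1/kernel': [[0.1, 0.2], [0.3, 0.4]],  # weight matrix
    'model/h1/bias': [0.0, 0.0],                  # bias vector
    'model/bias': [0.5],                          # custom state entry
}

# Iterate over all entries of one submodule by filtering key prefixes.
h1_keys = [k for k in state if k.startswith('model/h1/')]

# Modify a single entry in place, with no PyTree traversal needed.
state['model/bias'] = [v + 0.1 for v in state['model/bias']]
```

Because the state is just a dict, selecting, renaming, or partially loading parameters reduces to ordinary string and dict manipulation.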

Installation

Ninjax is a single file, so you can just copy it to your project directory. Or you can install the package:

pip install ninjax

Quickstart

import flax
import jax
import jax.numpy as jnp
import ninjax as nj
import optax

Linear = nj.FromFlax(flax.linen.Dense)


class MyModel(nj.Module):

  lr: float = 1e-3

  def __init__(self, size):
    self.size = size
    # Define submodules upfront
    self.h1 = Linear(128, name='h1')
    self.h2 = Linear(128, name='h2')
    self.opt = optax.adam(self.lr)

  def predict(self, x):
    x = jax.nn.relu(self.h1(x))
    x = jax.nn.relu(self.h2(x))
    # Define submodules inline
    x = self.sub('h3', Linear, self.size, use_bias=False)(x)
    # Create state entries inline
    x += self.value('bias', jnp.zeros, self.size)
    # Update state entries inline
    self.write('bias', self.read('bias') + 0.1)
    return x

  def loss(self, x, y):
    return ((self.predict(x) - y) ** 2).mean()

  def train(self, x, y):
    # Take gradients with respect to submodules or state keys
    wrt = [self.h1, self.h2, f'{self.path}/h3', f'{self.path}/bias']
    loss, params, grads = nj.grad(self.loss, wrt)(x, y)
    # Update weights
    state = self.sub('optstate', nj.Tree, self.opt.init, params)
    updates, new_state = self.opt.update(grads, state.read(), params)
    params = optax.apply_updates(params, updates)
    nj.context().update(params)  # Store the new params
    state.write(new_state)       # Store new optimizer state
    return loss


# Create model and example data
model = MyModel(3, name='model')
x = jnp.ones((64, 32), jnp.float32)
y = jnp.ones((64, 3), jnp.float32)

# Populate initial state from one or more functions
state = {}
state = nj.init(model.train)(state, x, y, seed=0)
print(state['model/bias'])

# Purify for JAX transformations
train = jax.jit(nj.pure(model.train))

# Training loop
for x, y in [(x, y)] * 10:
  state, loss = train(state, x, y)
  print('Loss:', float(loss))

# Look at the parameters
print(state['model/bias'])
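One consequence of the flat state layout is that checkpointing reduces to plain serialization. The following is an illustrative sketch, not an official Ninjax API; it uses pickle on a stand-in dict, where in practice the values would be jnp.arrays:

```python
import pickle

# Stand-in for the state dict returned by nj.init / nj.pure.
state = {'model/h1/kernel': [0.1, 0.2], 'model/bias': [0.5]}

# Save: the flat dict serializes directly, no custom tree handling.
blob = pickle.dumps(state)

# Load: the round trip restores the same key-to-array mapping.
restored = pickle.loads(blob)
```

The restored dict can be passed back into the purified train function in place of the original state.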

Questions

If you have a question, please file an issue.
