🥷 Ninjax: Flexible Modules for JAX

Ninjax is a general and practical module system for JAX. It gives users full and transparent control over updating the state of each module, bringing flexibility to JAX and enabling new use cases.

Overview

Ninjax provides a simple and general nj.Module class.

  • Modules can store state for things like model parameters, Adam momentum buffers, BatchNorm statistics, recurrent state, etc.
  • Modules can read and write their state entries. For example, this allows modules to have train methods, because they can update their parameters from the inside.
  • Any method can initialize, read, and write state entries. This avoids the need for a special build() method or @compact decorator used in Flax.
  • Ninjax makes it easy to mix and match modules from different libraries, such as Flax and Haiku.
  • Instead of PyTrees, Ninjax state is a flat dict that maps string keys like /net/layer1/weights to jnp.arrays. This makes it easy to iterate over, modify, and save or load state.
  • Modules can specify typed hyperparameters using the dataclass syntax.
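Because the state is a flat dict rather than a nested PyTree, everyday bookkeeping reduces to plain dictionary operations. A minimal sketch in plain Python (independent of Ninjax; the keys and shapes are made up for illustration):

```python
import numpy as np

# A flat state dict in the Ninjax style: string paths mapped to arrays.
state = {
    '/net/layer1/weights': np.zeros((32, 128)),
    '/net/layer1/bias': np.zeros(128),
    '/net/layer2/weights': np.zeros((128, 3)),
    '/opt/step': np.array(0),
}

# Selecting all entries of one submodule is a simple prefix filter.
layer1 = {k: v for k, v in state.items() if k.startswith('/net/layer1/')}
print(sorted(layer1))  # ['/net/layer1/bias', '/net/layer1/weights']

# Summaries over the whole state are one comprehension away.
sizes = {k: v.size for k, v in state.items()}
print(sizes['/net/layer2/weights'])  # 384
```

The same prefix-filtering pattern works for freezing parameters, partial checkpointing, or logging per-module statistics.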

Installation

Ninjax is a single file, so you can simply copy it into your project directory. Alternatively, install the package from PyPI:

pip install ninjax

Quickstart

import flax
import jax
import jax.numpy as jnp
import ninjax as nj
import optax

Linear = nj.FromFlax(flax.linen.Dense)


class MyModel(nj.Module):

  lr: float = 1e-3

  def __init__(self, size):
    self.size = size
    # Define submodules upfront
    self.h1 = Linear(128, name='h1')
    self.h2 = Linear(128, name='h2')
    self.opt = optax.adam(self.lr)

  def predict(self, x):
    x = jax.nn.relu(self.h1(x))
    x = jax.nn.relu(self.h2(x))
    # Define submodules inline
    x = self.sub('h3', Linear, self.size, use_bias=False)(x)
    # Create state entries inline
    x += self.value('bias', jnp.zeros, self.size)
    # Update state entries inline
    self.write('bias', self.read('bias') + 0.1)
    return x

  def loss(self, x, y):
    return ((self.predict(x) - y) ** 2).mean()

  def train(self, x, y):
    # Compute gradients with respect to submodules or state keys
    wrt = [self.h1, self.h2, f'{self.path}/h3', f'{self.path}/bias']
    loss, params, grads = nj.grad(self.loss, wrt)(x, y)
    # Update weights
    state = self.sub('optstate', nj.Tree, self.opt.init, params)
    updates, new_state = self.opt.update(grads, state.read(), params)
    params = optax.apply_updates(params, updates)
    nj.context().update(params)  # Store the new params
    state.write(new_state)       # Store new optimizer state
    return loss


# Create model and example data
model = MyModel(3, name='model')
x = jnp.ones((64, 32), jnp.float32)
y = jnp.ones((64, 3), jnp.float32)

# Populate initial state from one or more functions
state = {}
state = nj.init(model.train)(state, x, y, seed=0)
print(state['model/bias'])

# Purify for JAX transformations
train = jax.jit(nj.pure(model.train))

# Training loop
for x, y in [(x, y)] * 10:
  state, loss = train(state, x, y)
  print('Loss:', float(loss))

# Look at the parameters
print(state['model/bias'])
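Since the state returned by the purified function is a flat dict of arrays, checkpointing needs no special machinery. A sketch using pickle and an in-memory buffer (the keys and values here are illustrative, not produced by the Ninjax API; a file path works the same way):

```python
import io
import pickle

import numpy as np

# Illustrative flat state; in practice this is the dict returned by
# nj.init or the purified train function above.
state = {
    'model/bias': np.full(3, 0.5),
    'model/h1/kernel': np.ones((32, 128)),
}

# Serialize the whole flat dict in one call.
buf = io.BytesIO()
pickle.dump(state, buf)

# Restore it just as easily; no module tree needs to be rebuilt first,
# and the restored dict can be passed straight back into train().
buf.seek(0)
restored = pickle.load(buf)

assert set(restored) == set(state)
assert np.allclose(restored['model/bias'], state['model/bias'])
```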

Questions

If you have a question, please file an issue.
