Simple state handling for JAX


🥷 Ninjax

Ninjax brings the flexibility of PyTorch and TensorFlow 2 to JAX. It is a simple state manager for JAX that makes it easy to build nested components that update their own state. It's intended to be used together with a neural network library, such as Flax or Haiku.

Quickstart

import haiku as hk
import jax
import jax.numpy as jnp
import ninjax as nj


class Model(nj.Module):

  def __init__(self, size, act=jax.nn.relu):
    self.size = size
    self.act = act
    self.h1 = nj.HaikuModule(hk.Linear, 128)
    self.h2 = nj.HaikuModule(hk.Linear, 128)
    self.h3 = nj.HaikuModule(hk.Linear, size)

  def __call__(self, x):
    x = self.act(self.h1(x))
    x = self.act(self.h2(x))
    x = self.h3(x)
    return x

  def train(self, x, y):
    self(x)  # Create weights needed for gradient.
    state = self.state()
    loss, grad = nj.grad(self.loss, [self.h1, self.h2, self.h3])(x, y)
    state = jax.tree_map(lambda p, g: p - 0.01 * g, state, grad)
    self.update(state)
    return loss

  def loss(self, x, y):
    return ((self(x) - y) ** 2).mean()


model = Model(8)
main = jax.random.PRNGKey(0)
state = {}
for x, y in dataset:
  rng, main = jax.random.split(main)
  state, loss = nj.run(model.train, state, rng, x, y)
  print('Loss:', float(loss))

How To

How can I use JIT compilation?

The nj.run() function makes the state your JAX code uses explicit, so it can be jitted and transformed freely:

import functools

model = Model(8)
train = jax.jit(functools.partial(nj.run, model.train))
train(state, rng, ...)
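
The explicit-state idea can be seen without Ninjax at all. A minimal sketch in plain JAX, where a pure train function threads a (hypothetical) state dict through each call and is therefore freely jittable:

```python
import jax
import jax.numpy as jnp

def train(state, x):
    # Pure function: state in, updated state out. Because nothing is
    # hidden, jax.jit can trace and compile it without surprises.
    loss = (state['w'] * x) ** 2
    new_state = {'w': state['w'] - 0.1 * x}
    return new_state, loss

train_jit = jax.jit(train)
state = {'w': jnp.float32(1.0)}
state, loss = train_jit(state, jnp.float32(2.0))
```

The state names and update rule here are illustrative; the point is only that explicit state makes the function a valid input to jax.jit.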

How can I compute gradients?

You can use jax.grad as normal for computing gradients with respect to explicit inputs of your function. To compute gradients with respect to Ninjax state, use nj.grad(fn, keys):

class Module(nj.Module):

  def train(self, x, y):
    params = self.state()
    loss, grads = nj.grad(self.loss, params.keys())(x, y)
    params = jax.tree_map(lambda p, g: p - 0.01 * g, params, grads)
    self.update(params)
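
For comparison, here is the same SGD update written with plain jax.value_and_grad over an explicit parameter pytree (a self-contained sketch, not Ninjax code; the names loss_fn, w, and b are illustrative):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Mean squared error of a linear model.
    pred = params['w'] * x + params['b']
    return ((pred - y) ** 2).mean()

params = {'w': jnp.float32(0.0), 'b': jnp.float32(0.0)}
x = jnp.array([1.0, 2.0])
y = jnp.array([2.0, 4.0])
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
# Same SGD update as above, expressed over an explicit pytree.
params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
```

nj.grad does the same thing, except that the parameters come from the Ninjax state rather than from an explicit function argument.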

The self.state(filter) method optionally accepts a regex pattern to select only a subset of the state dictionary. It returns only the state entries of the current module; to access the global state, use nj.state().
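
To illustrate the idea of regex filtering over a flat state dictionary (the actual key format used by Ninjax is an implementation detail; the path-like keys below are just an assumption):

```python
import re

# Hypothetical flat state dict, mimicking per-module parameters
# stored under path-like keys.
state = {'/model/h1/w': 1.0, '/model/h1/b': 0.0, '/model/h2/w': 2.0}

def filter_state(state, pattern):
    # Keep only the entries whose key matches the regex pattern.
    regex = re.compile(pattern)
    return {k: v for k, v in state.items() if regex.search(k)}

h1_params = filter_state(state, r'/h1/')  # only the h1 entries
```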

How can I define modules compactly?

You can use self.get(name, ctor, *args, **kwargs) inside methods of your modules. When called for the first time, it creates a new state entry from the constructor ctor(*args, **kwargs). Later calls return the existing entry:

class Module(nj.Module):

  def __call__(self, x):
    x = jax.nn.relu(self.get('h1', Linear, 128)(x))
    x = jax.nn.relu(self.get('h2', Linear, 128)(x))
    x = self.get('h3', Linear, 32)(x)
    return x
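
The get-or-create behavior behind self.get can be sketched in a few lines of plain Python (a simplified stand-in, not the Ninjax implementation):

```python
class Module:
    """Simplified get-or-create registry for named sub-entries."""

    def __init__(self):
        self._entries = {}

    def get(self, name, ctor, *args, **kwargs):
        # First call creates the entry via ctor(*args, **kwargs);
        # later calls with the same name return the cached object.
        if name not in self._entries:
            self._entries[name] = ctor(*args, **kwargs)
        return self._entries[name]

m = Module()
a = m.get('h1', dict, size=128)
b = m.get('h1', dict, size=128)
assert a is b  # The same entry is reused on the second call.
```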

How can I use Haiku modules?

Haiku requires its modules to be passed through hk.transform and then initialized via transformed.init(rng, batch). Ninjax provides nj.HaikuModule to do this for you:

class Module(nj.Module):

  def __init__(self):
    self.mlp = nj.HaikuModule(hk.nets.MLP, [128, 128, 32])

  def __call__(self, x):
    return self.mlp(x)

You can also predefine aliases for Haiku modules that you want to use frequently:

Linear = functools.partial(nj.HaikuModule, hk.Linear)
Conv2D = functools.partial(nj.HaikuModule, hk.Conv2D)
MLP = functools.partial(nj.HaikuModule, hk.nets.MLP)
# ...
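
These aliases work because functools.partial simply binds the Haiku class as the first constructor argument. A minimal illustration with a stand-in wrapper class (FakeModule is hypothetical, standing in for nj.HaikuModule):

```python
import functools

class FakeModule:
    """Stand-in for nj.HaikuModule: records the ctor and its arguments."""
    def __init__(self, ctor, *args, **kwargs):
        self.ctor, self.args, self.kwargs = ctor, args, kwargs

Alias = functools.partial(FakeModule, dict)
m = Alias(a=1)
assert m.ctor is dict and m.kwargs == {'a': 1}
```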
