
nnx

Neural Networks for JAX

nnx is a lightweight module system for JAX that provides the same power as flax but with a simpler mental model and implementation. It is built on top of refx, which enables shared state, tractable mutability, and semantic partitioning. nnx also supports stateful transformations, allowing you to train your models efficiently.

Status

nnx is currently a proof of concept, meant to explore the design space of a lightweight module system for JAX based on refx.

Getting Started

To get started with nnx, first install the package using pip:

pip install nnx

Once you have installed nnx, you can define your modules as Pytrees. Here is an example of how to define a Linear module:

import nnx
import jax

# Modules are Pytrees
class Linear(nnx.Module):

    # mark parameter fields
    w: jax.Array = nnx.param()
    b: jax.Array = nnx.param()

    def __init__(self, din: int, dout: int):
        key = self.make_rng("params") # request an RNG key
        self.w = jax.random.uniform(key, (din, dout))
        self.b = jax.numpy.zeros((dout,))

    def __call__(self, x):
        return x @ self.w + self.b

In this example, Linear is a Pytree with two fields, w and b, both of which are marked as parameters using nnx.param().

To initialize a Linear module, you can use the init method:

model = Linear.init(jax.random.PRNGKey(0))(12, 2)

This will create a Linear module with din=12 and dout=2, using jax.random.PRNGKey(0) to seed the "params" RNG stream requested by make_rng inside __init__.
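
Once initialized, the module can be called like any Python callable. A minimal sketch (the batch size and input values below are illustrative, not part of the nnx API):

x = jax.numpy.ones((4, 12))  # a batch of 4 examples with 12 features each
y = model(x)                 # forward pass through __call__
print(y.shape)               # (4, 2)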

Stateful Transformations

nnx supports stateful transformations, which allow you to train your models efficiently. Here is an example of how to define and use a stateful transformation with nnx:

@nnx.jit
def train_step(model, x, y):

    def loss_fn(model):
        y_pred = model(x)
        return jax.numpy.mean((y_pred - y) ** 2)
    
    # compute gradient
    grad = nnx.grad(loss_fn, wrt="params")(model)
    # SGD update
    model["params"] = jax.tree_map(lambda w, g: w - 0.1 * g, model["params"], grad)

# stateful update, no return !!!
train_step(model, x, y)

In this example, train_step is a stateful transformation that takes a model, x, and y as inputs. The loss_fn function computes the loss of the model, and nnx.grad computes the gradient of the loss with respect to the parameters of the model. Finally, the parameters are updated in place with a stochastic gradient descent step; because nnx.jit is a stateful transformation, the updated model does not need to be returned.
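
Putting this together, a small training loop can simply call train_step repeatedly. The synthetic data and number of steps below are illustrative only:

# illustrative synthetic regression data
x = jax.random.uniform(jax.random.PRNGKey(1), (32, 12))
y = jax.random.uniform(jax.random.PRNGKey(2), (32, 2))

model = Linear.init(jax.random.PRNGKey(0))(12, 2)

for _ in range(100):
    train_step(model, x, y)  # the model is updated in place, nothing is returned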

Shared State

In nnx, it's possible to create modules that share state between each other. This can be useful when designing complex neural network architectures, as it allows you to reuse certain layers and reduce the number of learnable parameters.

Here's an example of how to create a module with shared state:

class Block(nnx.Module):
    def __init__(self, linear: nnx.Linear):
        self.linear = linear
        self.bn = nnx.BatchNorm(2)

    def __call__(self, x):
        return nnx.relu(self.bn(self.linear(x)))

class Model(nnx.Module):
    def __init__(self):
        shared = nnx.Linear(2, 2)
        self.block1 = Block(shared)
        self.block2 = Block(shared)

    def __call__(self, x):
        x = self.block1(x)
        x = self.block2(x)
        return x

In this example, the Model module contains two instances of the Block module, each of which shares the same nnx.Linear module. To run the model, you can use the apply method to set the use_running_average flag for all BatchNorm modules.

Here's an example of how to compute the loss for a Model instance:

def loss_fn(model: Model, x: jax.Array, y: jax.Array):
    y_pred = model.apply(use_running_average=False)(x)
    return jax.numpy.mean((y - y_pred) ** 2)

It's worth noting that the state for the shared nnx.Linear module will be kept in sync at all times on both Block instances, including during gradient updates.
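
Putting the pieces together, the sketch below creates a Model and evaluates the loss. The call to Model.init and the input shapes are illustrative, assuming it follows the same pattern as Linear.init above:

model = Model.init(jax.random.PRNGKey(0))()  # illustrative, mirrors Linear.init above

x = jax.numpy.ones((4, 2))   # batch of 4 examples with 2 features
y = jax.numpy.zeros((4, 2))  # dummy targets

loss = loss_fn(model, x, y)  # both blocks run through the shared nnx.Linear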
