
A stateful pytree library for training neural networks.


Introduction | Getting started | Functional programming | Examples | Modules | Fine-tuning


Introduction

PAX is a stateful pytree library for training neural networks. The main class of PAX is pax.Module.

A pax.Module object has two sides:

  • It is a normal python object which can be modified and called.
  • It is a pytree object whose leaves are ndarrays.

A pax.Module object manages the pytree and executes functions that depend on it. As a pytree, it can be passed into and returned from JAX functions running on CPU/GPU/TPU cores.
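The two sides can be imitated in plain Python, with no JAX or PAX required. The sketch below is a toy, not how PAX implements modules. `ToyLinear`, `tree_flatten`, `tree_unflatten` and `tree_map` are hypothetical names. The object is callable like a normal Python object, and it can also expose its leaves and be rebuilt from new leaves, which is what lets pytree-aware functions transform it without touching the original.

```python
from typing import List, Tuple


class ToyLinear:
    """Hypothetical stand-in for a module: a normal object that is also a 'tree'."""

    def __init__(self, weight: float, bias: float, counter: int = 0):
        self.weight = weight
        self.bias = bias
        self.counter = counter

    # Side 1: a normal Python object that can be called.
    def __call__(self, x: float) -> float:
        return self.weight * x + self.bias

    # Side 2: a tree whose leaves can be extracted and put back.
    def tree_flatten(self) -> Tuple[List[float], Tuple[str, ...]]:
        names = ("weight", "bias", "counter")
        return [getattr(self, n) for n in names], names

    @classmethod
    def tree_unflatten(cls, names: Tuple[str, ...], leaves: List[float]) -> "ToyLinear":
        obj = cls.__new__(cls)
        for name, leaf in zip(names, leaves):
            setattr(obj, name, leaf)
        return obj


def tree_map(fn, tree: ToyLinear) -> ToyLinear:
    """Apply fn to every leaf and rebuild a new object; the input is untouched."""
    leaves, aux = tree.tree_flatten()
    return type(tree).tree_unflatten(aux, [fn(leaf) for leaf in leaves])


net = ToyLinear(weight=2.0, bias=1.0)
doubled = tree_map(lambda v: v * 2, net)
print(net(3.0))      # 7.0
print(doubled(3.0))  # 14.0
```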

Installation

Install from PyPI:

pip install pax3

Or install the latest version from Github:

pip install git+https://github.com/ntt123/pax.git

## or install with the test extras to run tests and examples
pip install git+https://github.com/ntt123/pax.git#egg=pax3[test]

Getting started

import jax, pax, jax.numpy as jnp

class Linear(pax.Module):
    weight: jnp.ndarray
    bias: jnp.ndarray
    counter: jnp.ndarray
    parameters = pax.parameters_method("weight", "bias")  # trainable: weight, bias

    def __init__(self):
        super().__init__()
        self.weight = jnp.array(0.0)
        self.bias = jnp.array(0.0)
        self.counter = jnp.array(0)

    def __call__(self, x):
        self.counter = self.counter + 1  # internal state update (allowed only inside pure functions)
        return self.weight * x + self.bias

def loss_fn(model: Linear, x: jnp.ndarray, y: jnp.ndarray):
    model, y_hat = pax.purecall(model, x)  # returns the updated model and the output
    loss = jnp.mean(jnp.square(y_hat - y))
    return loss, (loss, model)

grad_fn = jax.grad(loss_fn, has_aux=True, allow_int=True)

net = Linear()
x, y = jnp.array(1.0), jnp.array(1.0)
grads, (loss, net) = grad_fn(net, x, y)
print(grads.counter)  # (b'',)
print(grads.bias)  # -2.0

There are a few noteworthy points in the above example:

  • weight and bias are declared as trainable parameters by setting parameters = pax.parameters_method("weight", "bias").
  • pax.purecall(model, x) executes model(x) and returns the updated model in the output.
  • loss_fn returns the updated model in the output.
  • jax.grad(..., allow_int=True) is required because the model contains an integer leaf (counter). JAX allows differentiating with respect to such leaves and returns a trivial float0 gradient for them.
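As a sanity check on the printed value, the loss in the example is a squared error on a single sample, so the gradients can be worked out by hand with w = 0, b = 0, x = 1, y = 1 (the initial values in the example):

```latex
\begin{aligned}
\hat{y} &= w x + b, \qquad L = (\hat{y} - y)^2 \\
\frac{\partial L}{\partial b} &= 2(\hat{y} - y) = 2(0 \cdot 1 + 0 - 1) = -2 \\
\frac{\partial L}{\partial w} &= 2(\hat{y} - y)\,x = 2(0 - 1)\cdot 1 = -2
\end{aligned}
```

The first line matches the -2.0 printed for grads.bias.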

PAX functional programming

pax.pure

It is good practice to keep functions of PAX modules pure (free of side effects).

Following this practice, PAX restricts modifications of a module's internal state. Only functions decorated with pax.pure may modify modules, and they modify copies of their input modules; changes to these copies do not affect the original inputs. Consequently, the only way to update an input module is to return it in the output.

net = Linear()
net(0)
# ...
# ValueError: Cannot modify a module in immutable mode.
# Please do this computation inside a function decorated by `pax.pure`.

@pax.pure
def update_counter_wo_return(m: Linear):
    m(0)

print(net.counter)
# 0
update_counter_wo_return(net)
print(net.counter) # the same counter
# 0

@pax.pure
def update_counter_and_return(m: Linear):
    m(0)
    return m

print(net.counter)
# 0
net = update_counter_and_return(net)
print(net.counter) # increased by 1
# 1
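Why the first function has no visible effect can be reproduced in plain Python. This is a sketch under an assumption: `toy_pure` is a hypothetical decorator that deep-copies its arguments with `copy.deepcopy`. It only approximates the observable behavior described above and does not reproduce PAX's immutable mode outside pure functions.

```python
import copy
import functools


class Counter:
    """Hypothetical module with a single piece of internal state."""

    def __init__(self):
        self.counter = 0

    def __call__(self, x):
        self.counter += 1
        return x


def toy_pure(fn):
    """Call fn on deep copies of its arguments so the caller's objects never change."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        args = copy.deepcopy(args)
        kwargs = copy.deepcopy(kwargs)
        return fn(*args, **kwargs)

    return wrapper


@toy_pure
def update_without_return(m):
    m(0)  # mutates the copy only


@toy_pure
def update_and_return(m):
    m(0)
    return m  # returning the copy is the only way to get the update out


net = Counter()
update_without_return(net)
print(net.counter)  # 0
net = update_and_return(net)
print(net.counter)  # 1
```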

pax.purecall

For convenience, PAX provides the pax.purecall function. It is a shortcut for pax.pure(lambda f, x: [f, f(x)]). Note that the function also returns the updated module in its output. For example:

net = Linear()
print(net.counter) # 0
net, y = pax.purecall(net, 0)
print(net.counter) # 1
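The contract of purecall, which calls a copy and returns both the updated copy and the output, can be sketched without PAX. `toy_purecall` and `Counter` are hypothetical, and copying via `copy.deepcopy` is an assumption made for illustration. The sketch returns a tuple instead of a list, which unpacks the same way.

```python
import copy


class Counter:
    """Hypothetical callable module that counts its calls."""

    def __init__(self):
        self.counter = 0

    def __call__(self, x):
        self.counter += 1
        return x * 2


def toy_purecall(module, *args, **kwargs):
    """Call a copy of module and return (updated_copy, output), like [f, f(x)]."""
    module = copy.deepcopy(module)
    output = module(*args, **kwargs)
    return module, output


net = Counter()
new_net, y = toy_purecall(net, 3)
print(net.counter, new_net.counter, y)  # 0 1 6
```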

Replacing parts

PAX provides utility methods to modify a module in a functional way.

The replace method creates a new module with attributes replaced. For example, to replace weight and bias of a pax.Linear module:

fc = pax.Linear(2, 2)
fc = fc.replace(weight=jnp.ones((2,2)), bias=jnp.zeros((2,)))
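The replace method is similar in spirit to `dataclasses.replace` from the Python standard library: a new object is built with some fields swapped, and the original is left intact. The frozen `ToyLinear` dataclass below is hypothetical and stores nested lists instead of JAX arrays so it runs without JAX; it illustrates the pattern, not PAX's implementation.

```python
import dataclasses
from typing import List


@dataclasses.dataclass(frozen=True)
class ToyLinear:
    """Hypothetical immutable 'module' holding a weight matrix and a bias vector."""
    weight: List[List[float]]
    bias: List[float]


fc = ToyLinear(weight=[[0.5, 0.5], [0.5, 0.5]], bias=[0.1, 0.1])

# Build a new object with both fields replaced; `fc` itself is not modified.
fc2 = dataclasses.replace(fc, weight=[[1.0, 1.0], [1.0, 1.0]], bias=[0.0, 0.0])

print(fc.bias)   # [0.1, 0.1]
print(fc2.bias)  # [0.0, 0.0]
```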

The replace_node method replaces a pytree node of a module:

f = pax.Sequential(
    pax.Linear(2, 3),
    pax.Linear(3, 4),
)

f = f.replace_node(f[-1], pax.Linear(3, 5))
print(f.summary())
# Sequential
# ├── Linear(in_dim=2, out_dim=3, with_bias=True)
# └── Linear(in_dim=3, out_dim=5, with_bias=True)
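The idea behind replace_node is to locate a node by identity and return a new container with that node swapped. The following is a minimal plain-Python sketch of that idea. `ToyLayer`, `ToySequential` and its `replace_node` are hypothetical, and unlike PAX they only look at a flat list of children.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True, eq=False)  # eq=False: layers compare by identity
class ToyLayer:
    in_dim: int
    out_dim: int


@dataclass(frozen=True)
class ToySequential:
    layers: Tuple[ToyLayer, ...]

    def __getitem__(self, i):
        return self.layers[i]

    def replace_node(self, node, new_node):
        """Return a new ToySequential where the child that *is* `node` is swapped."""
        return ToySequential(
            tuple(new_node if layer is node else layer for layer in self.layers)
        )


f = ToySequential((ToyLayer(2, 3), ToyLayer(3, 4)))
g = f.replace_node(f[-1], ToyLayer(3, 5))
print([(layer.in_dim, layer.out_dim) for layer in g.layers])  # [(2, 3), (3, 5)]
print([(layer.in_dim, layer.out_dim) for layer in f.layers])  # [(2, 3), (3, 4)]
```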

PAX and other libraries

PAX borrows many ideas from other libraries:

  • PAX borrows the idea that a module is also a pytree from treex and equinox.
  • PAX uses the concept of trainable parameters and non-trainable states from dm-haiku.
  • PAX has methods similar to PyTorch's, such as model.apply(), model.parameters(), model.eval(), etc.
  • PAX uses objax's approach to implement optimizers as modules.
  • PAX uses the jmp library to support mixed precision.
  • And of course, PAX is heavily influenced by JAX's functional programming approach.

Examples

A good way to learn about PAX is to see examples in the examples/ directory.

  • char_rnn.py: train an RNN language model on TPU.
  • transformer/: train a Transformer language model on TPU.
  • mnist.py: train an image classifier on the MNIST dataset.
  • notebooks/VAE.ipynb: train a variational autoencoder.
  • notebooks/DCGAN.ipynb: train a DCGAN model on the Celeb-A dataset.
  • notebooks/fine_tuning_resnet18.ipynb: fine-tune a pretrained ResNet18 model on the cats vs. dogs dataset.
  • notebooks/mixed_precision.ipynb: train a U-Net image segmentation model with mixed precision.
  • mnist_mixed_precision.py: train an image classifier with mixed precision.
  • wave_gru/: train a WaveGRU vocoder that converts mel-spectrograms to waveforms.
  • denoising_diffusion/: train a denoising diffusion model on the Celeb-A dataset.

Modules

At the moment, PAX includes:

  • pax.Embed,
  • pax.Linear,
  • pax.{GRU, LSTM},
  • pax.{BatchNorm1D, BatchNorm2D, LayerNorm, GroupNorm},
  • pax.{Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose},
  • pax.{Dropout, Sequential, Identity, Lambda, RngSeq, EMA}.

We intend to add new modules in the near future.

Optimizers

PAX's optimizers are implemented in a separate library, opax. The opax library supports many common optimizers such as adam, adamw, sgd, and rmsprop. Visit opax's GitHub repository for more information.
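The objax-style idea of an optimizer as a module is that its internal buffers, such as momentum, are state held by an object rather than globals. A hypothetical plain-Python SGD with momentum illustrates this: `step` returns a new optimizer together with new parameters. `ToySGDMomentum` and `step` are illustrative names only and are not opax's API; see the opax repository for the real interface.

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass(frozen=True)
class ToySGDMomentum:
    """Hypothetical optimizer-as-module: hyperparameters plus a momentum state."""
    learning_rate: float
    momentum: float
    velocity: List[float]

    def step(self, params: List[float], grads: List[float]):
        """Return (new_optimizer, new_params); nothing is mutated in place."""
        velocity = [self.momentum * v + g for v, g in zip(self.velocity, grads)]
        new_params = [p - self.learning_rate * v for p, v in zip(params, velocity)]
        return replace(self, velocity=velocity), new_params


params = [0.0, 0.0]
opt = ToySGDMomentum(learning_rate=0.1, momentum=0.9, velocity=[0.0, 0.0])
grads = [-2.0, -2.0]  # example gradients
opt, params = opt.step(params, grads)
print(params)        # [0.2, 0.2]
print(opt.velocity)  # [-2.0, -2.0]
```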

Module transformations

A module transformation is a pure function that transforms PAX modules into new PAX modules. A PAX program can be seen as a series of module transformations.

PAX provides several module transformations:

  • pax.select_{parameters,states}: select parameter/state leaves.
  • pax.update_{parameters,states}: update a module's parameters/states.
  • pax.enable_{train,eval}_mode: switch between training and evaluation mode.
  • pax.(un)freeze_parameters: freeze/unfreeze trainable parameters.
  • pax.apply_mp_policy: apply a mixed-precision policy.
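Because each transformation maps a module to a new module, transformations can be chained without mutating anything. The sketch below is a hypothetical toy: `ToyModule` is a frozen dataclass with a `training` flag and child modules, and `toy_enable_eval_mode` recursively builds a new tree with the flag turned off. The name echoes pax.enable_eval_mode, but the code only illustrates the pattern and is not PAX's implementation.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class ToyModule:
    """Hypothetical module: a name, a training flag, and child modules."""
    name: str
    training: bool = True
    children: Tuple["ToyModule", ...] = ()


def toy_enable_eval_mode(m: ToyModule) -> ToyModule:
    """Module transformation: return a new tree with training=False everywhere."""
    return replace(
        m,
        training=False,
        children=tuple(toy_enable_eval_mode(c) for c in m.children),
    )


def all_flags(m: ToyModule):
    """Collect (name, training) pairs in depth-first order."""
    out = [(m.name, m.training)]
    for c in m.children:
        out.extend(all_flags(c))
    return out


net = ToyModule("seq", children=(ToyModule("dropout"), ToyModule("batchnorm")))
eval_net = toy_enable_eval_mode(net)
print(all_flags(net))       # [('seq', True), ('dropout', True), ('batchnorm', True)]
print(all_flags(eval_net))  # [('seq', False), ('dropout', False), ('batchnorm', False)]
```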

Fine-tuning models

PAX provides the pax.freeze_parameters transformation, which converts all trainable parameters of a module into non-trainable states.

net = pax.Sequential(
    pax.Linear(28*28, 64),
    jax.nn.relu,
    pax.Linear(64, 10),
)

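net = pax.freeze_parameters(net)  # every layer's parameters become non-trainable states
net = net.set(-1, pax.Linear(64, 2))  # new output layer with trainable parameters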

After this, net.parameters() returns only the trainable parameters of the new last layer.
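The bookkeeping behind freezing can be sketched in plain Python. This is a toy under assumptions: each hypothetical `ToyLayer` records which of its attributes are trainable, `toy_freeze` returns copies with that set emptied, and swapping in a fresh last layer restores trainability for that layer only. `toy_freeze` and `toy_parameters` are illustrative names, not PAX functions.

```python
from dataclasses import dataclass, replace
from typing import FrozenSet, Tuple


@dataclass(frozen=True)
class ToyLayer:
    """Hypothetical layer; `trainable` names the attributes treated as parameters."""
    name: str
    trainable: FrozenSet[str] = frozenset({"weight", "bias"})


def toy_freeze(layers: Tuple[ToyLayer, ...]) -> Tuple[ToyLayer, ...]:
    """Return copies of all layers with no trainable attributes (all become states)."""
    return tuple(replace(layer, trainable=frozenset()) for layer in layers)


def toy_parameters(layers):
    """List (layer name, attribute) pairs that are still trainable."""
    return sorted((layer.name, a) for layer in layers for a in layer.trainable)


net = (ToyLayer("fc1"), ToyLayer("fc2"))
net = toy_freeze(net)
net = net[:-1] + (ToyLayer("fc2_new"),)  # swap in a fresh, trainable last layer
print(toy_parameters(net))  # [('fc2_new', 'bias'), ('fc2_new', 'weight')]
```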

Download files


Source distribution: pax3-0.5.2.tar.gz (55.7 kB)

Built distribution: pax3-0.5.2-py3-none-any.whl (70.7 kB), for Python 3
