A stateful pytree library for training neural networks.
Introduction | Getting started | Functional programming | Examples | Modules | Fine-tuning
## Introduction

PAX is a stateful pytree library for training neural networks. The main class of PAX is `pax.Module`.

A `pax.Module` object has two sides:

- It is a normal python object which can be modified and called.
- It is a pytree object whose leaves are `ndarray`s.

A `pax.Module` object manages the pytree and executes functions that depend on the pytree. As a pytree object, it can be passed to and returned from JAX functions running on CPU/GPU/TPU cores.
## Installation

Install from PyPI:

```bash
pip install pax3
```

Or install the latest version from GitHub:

```bash
pip install git+https://github.com/ntt123/pax.git

## or test mode to run tests and examples
pip install git+https://github.com/ntt123/pax.git#egg=pax3[test]
```
## Getting started

```python
import jax, pax, jax.numpy as jnp


class Linear(pax.Module):
    weight: jnp.ndarray
    bias: jnp.ndarray
    counter: jnp.ndarray
    parameters = pax.parameters_method("weight", "bias")

    def __init__(self):
        super().__init__()
        self.weight = jnp.array(0.0)
        self.bias = jnp.array(0.0)
        self.counter = jnp.array(0)

    def __call__(self, x):
        self.counter = self.counter + 1
        return self.weight * x + self.bias


def loss_fn(model: Linear, x: jnp.ndarray, y: jnp.ndarray):
    model, y_hat = pax.purecall(model, x)
    loss = jnp.mean(jnp.square(y_hat - y))
    return loss, (loss, model)


grad_fn = jax.grad(loss_fn, has_aux=True, allow_int=True)

net = Linear()
x, y = jnp.array(1.0), jnp.array(1.0)
grads, (loss, net) = grad_fn(net, x, y)
print(grads.counter)  # (b'',)
print(grads.bias)     # -2.0
```
There are a few noteworthy points in the above example (a training-step sketch follows this list):

- `weight` and `bias` are trainable parameters, registered by `parameters = pax.parameters_method("weight", "bias")`.
- `pax.purecall(model, x)` executes `model(x)` and returns the updated `model` in the output.
- `loss_fn` returns the updated `model` in the output.
- `jax.grad(..., allow_int=True)` allows gradients with respect to integer ndarray leaves (e.g., `counter`).
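Putting these pieces together, a training step is just a gradient computation followed by a parameter update. The sketch below is not confirmed API: the exact signatures of `pax.select_parameters` and `pax.update_parameters` (in particular the `params=` keyword) are assumptions based on their descriptions in the Module transformations section further down.

```python
# A minimal SGD step built from the example above (a sketch; the exact
# signatures of pax.select_parameters and pax.update_parameters are
# assumptions).
def sgd_step(model: Linear, x: jnp.ndarray, y: jnp.ndarray, lr: float = 0.1):
    grads, (loss, model) = grad_fn(model, x, y)
    # Restrict both pytrees to parameter leaves so the integer `counter`
    # (and its degenerate float0 gradient) is left untouched.
    params = pax.select_parameters(model)
    param_grads = pax.select_parameters(grads)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, param_grads)
    return pax.update_parameters(model, params=new_params), loss


net, loss = sgd_step(net, x, y)
print(loss)  # 1.0, since weight = bias = 0 gives y_hat = 0 and (0 - 1)**2 = 1
```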
## PAX functional programming

### pax.pure

It is good practice to keep functions of PAX modules pure (i.e., free of side effects). Following this practice, modifications of a PAX module's internal state are restricted. Only functions decorated with `pax.pure` are allowed to modify their input modules, and even then they modify a copy: any modification of the copy will not affect the original inputs. As a consequence, the only way to update an input module is to return the updated module in the output.
```python
net = Linear()
net(0)
# ...
# ValueError: Cannot modify a module in immutable mode.
# Please do this computation inside a function decorated by `pax.pure`.


@pax.pure
def update_counter_wo_return(m: Linear):
    m(0)


print(net.counter)
# 0
update_counter_wo_return(net)
print(net.counter)  # the same counter
# 0


@pax.pure
def update_counter_and_return(m: Linear):
    m(0)
    return m


print(net.counter)
# 0
net = update_counter_and_return(net)
print(net.counter)  # increased by 1
# 1
```
### pax.purecall

For convenience, PAX provides the `pax.purecall` function. It is a shortcut for `pax.pure(lambda f, x: [f, f(x)])`. Note that the function also returns the updated module in its output. For example:
```python
net = Linear()
print(net.counter)  # 0
net, y = pax.purecall(net, 0)
print(net.counter)  # 1
```
### Replacing parts

PAX provides utility methods to modify a module in a functional way.

The `replace` method creates a new module with attributes replaced. For example, to replace the `weight` and `bias` of a `pax.Linear` module:
```python
fc = pax.Linear(2, 2)
fc = fc.replace(weight=jnp.ones((2, 2)), bias=jnp.zeros((2,)))
```
The `replace_node` method replaces a pytree node of a module:

```python
f = pax.Sequential(
    pax.Linear(2, 3),
    pax.Linear(3, 4),
)

f = f.replace_node(f[-1], pax.Linear(3, 5))
print(f.summary())
# Sequential
# ├── Linear(in_dim=2, out_dim=3, with_bias=True)
# └── Linear(in_dim=3, out_dim=5, with_bias=True)
```
## PAX and other libraries

PAX learns a lot from other libraries:

- PAX borrows the idea that a module is also a pytree from treex and equinox.
- PAX uses the concept of trainable parameters and non-trainable states from dm-haiku.
- PAX has methods similar to PyTorch's, such as `model.apply()`, `model.parameters()`, `model.eval()`, etc.
- PAX uses objax's approach of implementing optimizers as modules.
- PAX uses the jmp library to support mixed precision.
- And of course, PAX is heavily influenced by JAX's functional programming approach.
## Examples

A good way to learn about PAX is to look at the examples in the `examples/` directory:
| Path | Description |
|---|---|
| `char_rnn.py` | train an RNN language model on TPU. |
| `transformer/` | train a Transformer language model on TPU. |
| `mnist.py` | train an image classifier on the MNIST dataset. |
| `notebooks/VAE.ipynb` | train a variational autoencoder. |
| `notebooks/DCGAN.ipynb` | train a DCGAN model on the Celeb-A dataset. |
| `notebooks/fine_tuning_resnet18.ipynb` | fine-tune a pretrained ResNet18 model on the cats-vs-dogs dataset. |
| `notebooks/mixed_precision.ipynb` | train a U-Net image segmentation model with mixed precision. |
| `mnist_mixed_precision.py` | train an image classifier with mixed precision. |
| `wave_gru/` | train a WaveGRU vocoder: convert mel-spectrograms to waveforms. |
| `denoising_diffusion/` | train a denoising diffusion model on the Celeb-A dataset. |
## Modules

At the moment, PAX includes: `pax.Embed`, `pax.Linear`, `pax.{GRU, LSTM}`, `pax.{BatchNorm1D, BatchNorm2D, LayerNorm, GroupNorm}`, `pax.{Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose}`, and `pax.{Dropout, Sequential, Identity, Lambda, RngSeq, EMA}`. We intend to add new modules in the near future.
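These building blocks compose in the same way as the `pax.Sequential` examples above. A small illustrative sketch follows; the `pax.Linear` and `pax.Sequential` signatures match the usage shown elsewhere in this README, while the `pax.Dropout` argument (a dropout rate) is an assumption.

```python
# A small MLP classifier composed from the built-in modules above.
# The pax.Dropout argument (a dropout rate) is an assumption.
mlp = pax.Sequential(
    pax.Linear(28 * 28, 256),
    jax.nn.relu,
    pax.Dropout(0.5),
    pax.Linear(256, 10),
)
print(mlp.summary())
```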
## Optimizers

PAX's optimizers are implemented in a separate library, opax. The opax library supports many common optimizers such as `adam`, `adamw`, `sgd`, and `rmsprop`. Visit opax's GitHub repository for more information.
## Module transformations

A module transformation is a pure function that transforms PAX modules into new PAX modules. A PAX program can be seen as a series of module transformations.

PAX provides several module transformations (a usage sketch follows the list):

- `pax.select_{parameters,states}`: select parameter/state leaves.
- `pax.update_{parameters,states}`: update a module's parameters/states.
- `pax.enable_{train,eval}_mode`: turn training mode on/off.
- `pax.(un)freeze_parameters`: freeze/unfreeze trainable parameters.
- `pax.apply_mp_policy`: apply a mixed-precision policy.
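As a sketch of how these are used (the single-module call forms are assumptions based on the descriptions above), each transformation returns a new module and leaves its input unchanged:

```python
# Illustrative only: call forms are assumptions based on the list above.
net = Linear()
params = pax.select_parameters(net)   # keep only the parameter leaves
eval_net = pax.enable_eval_mode(net)  # a copy with training mode turned off
frozen = pax.freeze_parameters(net)   # all parameters become states
```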
## Fine-tuning models

PAX's Module provides the `pax.freeze_parameters` transformation to convert all trainable parameters into non-trainable states:
```python
net = pax.Sequential(
    pax.Linear(28 * 28, 64),
    jax.nn.relu,
    pax.Linear(64, 10),
)

net = pax.freeze_parameters(net)
net = net.set(-1, pax.Linear(64, 2))
```
After this, `net.parameters()` will only return the trainable parameters of the last layer.
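As an illustrative check (assuming `net.parameters()` drops the non-trainable entries from the pytree leaves), only the new head's `weight` and `bias` should remain:

```python
# Count the remaining trainable leaves; the expected value of 2 assumes
# pax.Linear(64, 2) contributes exactly a weight and a bias.
n_trainable = len(jax.tree_util.tree_leaves(net.parameters()))
print(n_trainable)  # expected: 2
```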