A stateful pytree library for training neural networks.
Introduction | Getting started | Functional programming | Examples | Modules | Fine-tuning
Introduction
PAX is a stateful pytree library for training neural networks. The main class in PAX is pax.Module.
A pax.Module object has two sides:
- It is a normal Python object which can be modified and called (it has a __call__ method).
- It is a pytree object whose leaves are ndarrays.
A pax.Module object manages the pytree and executes functions that depend on the pytree. As a pytree object, it can be passed into and returned from JAX functions running on CPU/GPU/TPU cores.
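The two sides can be illustrated with plain JAX. The sketch below is not PAX's implementation: it registers a hypothetical Scale class as a pytree with jax.tree_util.register_pytree_node_class, so that the same object can be called like a function and also passed through jax.jit.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Scale:
    """Hypothetical example class: a callable object that is also a pytree."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x):
        # Side 1: behaves like a normal callable Python object.
        return self.weight * x

    def tree_flatten(self):
        # Side 2: expose the ndarray leaves (and no static aux data) to JAX.
        return (self.weight,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from its leaves.
        return cls(*children)


@jax.jit
def apply(module, x):
    # The module crosses the jit boundary as a pytree.
    return module(x)


scale = Scale(jnp.array(2.0))
leaves = jax.tree_util.tree_leaves(scale)  # a list with the single weight leaf
y = apply(scale, jnp.array(3.0))
```

pax.Module takes care of this kind of registration, so user code does not need to write tree_flatten and tree_unflatten by hand.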
Installation
Install from PyPI:
pip3 install pax3
Or install the latest version from GitHub:
pip3 install git+https://github.com/ntt123/pax.git
## or test mode to run tests and examples
pip3 install git+https://github.com/ntt123/pax.git#egg=pax[test]
Getting started
import jax
import jax.numpy as jnp
import pax

class Counter(pax.Module):
    def __init__(self, start_value: int = 0):
        super().__init__()
        self.register_parameter("bias", jnp.array(0.0))
        self.register_state("counter", jnp.array(start_value))

    def __call__(self, x):
        self.counter = self.counter + 1
        return self.counter * x + self.bias

@pax.pure
def loss_fn(model: Counter, x: jnp.ndarray):
    y = model(x)
    loss = jnp.mean(jnp.square(x - y))
    return loss, (loss, model)

grad_fn = jax.grad(loss_fn, has_aux=True, allow_int=True)

net = Counter(3)
x = jnp.array(10.)
grads, (loss, net) = grad_fn(net, x)
print(grads.counter) # (b'',)
print(grads.bias) # 60.0
There are a few important things in the above example:
- bias is registered as a trainable parameter using the register_parameter method.
- counter is registered as a non-trainable state using the register_state method.
- loss_fn is decorated by pax.pure and it returns the updated model in the output.
- allow_int=True is needed to compute gradients with respect to model, which contains integer ndarray leaves.
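The effect of allow_int=True can be seen in isolation with plain JAX. The sketch below uses a dict pytree as a stand-in for a module (the names params and loss_fn are illustrative only). The gradient for the integer leaf comes back as a placeholder of dtype float0, which is what prints as (b'',) for grads.counter.

```python
import jax
import jax.numpy as jnp

# A plain dict pytree standing in for a module: one float leaf, one integer leaf.
params = {"bias": jnp.array(0.5), "counter": jnp.array(3)}


def loss_fn(p, x):
    return (p["counter"] * x + p["bias"]) ** 2


# Without allow_int=True, jax.grad raises a TypeError for the integer leaf.
grads = jax.grad(loss_fn, allow_int=True)(params, jnp.array(2.0))

bias_grad = grads["bias"]        # ordinary float gradient
counter_grad = grads["counter"]  # placeholder with dtype float0
```

An optimizer should skip float0 leaves, since they carry no gradient information.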
PAX functional programming
Let "PAX function" mean functions whose inputs contain PAX modules.
It is good practice to make sure PAX functions have no side effects, in keeping with JAX's functional programming model.
Even though PAX modules are stateful objects, modifications of a PAX module's internal state are restricted.
Only PAX functions decorated by pax.pure are allowed to modify PAX modules.
net = Counter(3)
net(0)
# ...
# ValueError: Cannot modify a module in immutable mode.
# Please do this computation inside a function decorated by `pax.pure`.
Furthermore, a decorated function can only access a copy of its inputs. Any modification of the copy will not affect the original inputs.
@pax.pure
def update_counter_wo_return(m: Counter):
    m(0)

print(net.counter)
# 3
update_counter_wo_return(net)
print(net.counter) # the same counter
# 3
As a consequence, the only way to update an input module is to return it in the output.
@pax.pure
def update_counter(m: Counter):
    m(0)
    return m

print(net.counter)
# 3
net = update_counter(net)
print(net.counter) # increased by 1
# 4
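The copy-in, return-out behaviour can be mimicked in plain Python. The sketch below is not how pax.pure is implemented: pure_like and ToyCounter are hypothetical names, and copy.deepcopy stands in for whatever copying PAX actually performs.

```python
import copy
import functools


def pure_like(fn):
    """Hypothetical stand-in for the copy behaviour described above."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Deep-copy every input so the caller's objects are never mutated.
        args, kwargs = copy.deepcopy((args, kwargs))
        return fn(*args, **kwargs)

    return wrapper


class ToyCounter:
    def __init__(self, value):
        self.value = value

    def step(self):
        self.value += 1


@pure_like
def step_without_return(c):
    c.step()  # mutates only the copy, which is then discarded


@pure_like
def step_and_return(c):
    c.step()
    return c  # the updated copy is handed back to the caller


original = ToyCounter(3)
step_without_return(original)
unchanged_value = original.value      # the caller's object was not touched
updated = step_and_return(original)   # the new state lives in the return value
```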
PAX and other libraries
PAX modules have several methods that are similar to PyTorch's:
- self.register_parameter(name, value) registers name as a trainable parameter.
- self.apply(func) applies func on all modules of self recursively.
- self.train() and self.eval() return a new module in train/eval mode.
PAX learns a lot from other libraries too:
- PAX borrows the idea that a module is also a pytree from treex and equinox.
- PAX uses the concept of trainable parameters and non-trainable states from dm-haiku.
- PAX uses objax's approach to implement optimizers as modules.
- PAX uses jmp library for supporting mixed precision.
- And of course, PAX is heavily influenced by JAX's functional programming approach.
Examples
A good way to learn about PAX is to see examples in the examples/ directory:
Path | Description |
---|---|
char_rnn.py | train an RNN language model on TPU. |
transformer/ | train a Transformer language model on TPU. |
mnist.py | train an image classifier on MNIST dataset. |
notebooks/VAE.ipynb | train a variational autoencoder. |
notebooks/DCGAN.ipynb | train a DCGAN model on Celeb-A dataset. |
notebooks/fine_tuning_resnet18.ipynb | finetune a pretrained ResNet18 model on cats vs dogs dataset. |
notebooks/mixed_precision.ipynb | train a U-Net image segmentation with mixed precision. |
mnist_mixed_precision.py (experimental) | train an image classifier with mixed precision. |
wave_gru/ | train a WaveGRU vocoder: convert mel-spectrogram to waveform. |
denoising_diffusion/ | train a denoising diffusion model on Celeb-A dataset. |
Modules
At the moment, PAX includes:
- pax.nn.Embed,
- pax.nn.Linear,
- pax.nn.{GRU, LSTM},
- pax.nn.{BatchNorm1D, BatchNorm2D, LayerNorm, GroupNorm},
- pax.nn.{Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose},
- pax.nn.{Dropout, Sequential, Identity, Lambda, RngSeq, EMA}.
We intend to add new modules in the near future.
Optimizers
PAX has its optimizers implemented in a separate library, opax. The opax library supports many common optimizers such as adam, adamw, sgd, and rmsprop. Visit opax's GitHub repository for more information.
Module transformations
A module transformation is a pure function that takes PAX modules as input and returns PAX modules as output. A PAX program can be seen as a series of module transformations.
PAX provides several module transformations:
- pax.select_{parameters,states}: select parameter/state leaves.
- pax.apply_gradients: update model & optimizer using gradients.
- pax.update_{parameters,states}: update a module's parameters/states.
- pax.enable_{train,eval}_mode: turn on/off training mode.
- pax.(un)freeze_parameters: freeze/unfreeze trainable parameters.
- pax.apply_mp_policy: apply a mixed-precision policy.
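The idea of a transformation as a pure function from pytree to pytree can be sketched with plain JAX. The example below does not use PAX's own transformations: sgd_update is a hypothetical helper, and a dict stands in for a module.

```python
import jax
import jax.numpy as jnp


def sgd_update(params, grads, learning_rate=0.1):
    """Hypothetical pure update: returns a new pytree, leaves the input untouched."""
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)


def loss_fn(params, x, y):
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)


params = {"w": jnp.array(1.0), "b": jnp.array(0.0)}
x, y = jnp.array(2.0), jnp.array(6.0)

grads = jax.grad(loss_fn)(params, x, y)
new_params = sgd_update(params, grads)  # params itself is unchanged
```

Because each step returns a new pytree instead of mutating the old one, steps like this compose naturally and can be wrapped in jax.jit.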
Fine-tuning models
PAX provides the pax.freeze_parameters transformation, which converts all trainable parameters to non-trainable states.
net = pax.nn.Sequential(
    pax.nn.Linear(28*28, 64),
    jax.nn.relu,
    pax.nn.Linear(64, 10),
)
net = pax.freeze_parameters(net)
net = net.set(-1, pax.nn.Linear(64, 2))
After this, net.parameters() will only return the trainable parameters of the last layer.
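The effect of freezing can be approximated in plain JAX by keeping frozen and trainable leaves in separate pytrees and differentiating only with respect to the trainable one. This is a sketch of the idea under that assumption, not of how pax.freeze_parameters works internally, and the names frozen, trainable, w1, and w2 are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-layer model stored as two separate pytrees.
frozen = {"w1": jnp.ones((4, 3))}      # pretrained body, kept fixed
trainable = {"w2": jnp.zeros((3, 2))}  # new head, to be fine-tuned


def loss_fn(trainable, frozen, x, y):
    hidden = jax.nn.relu(x @ frozen["w1"])
    logits = hidden @ trainable["w2"]
    return jnp.mean((logits - y) ** 2)


x = jnp.ones((5, 4))
y = jnp.ones((5, 2))

# jax.grad differentiates with respect to the first argument only by default,
# so gradients exist for the head but not for the frozen body.
grads = jax.grad(loss_fn)(trainable, frozen, x, y)
```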