A stateful pytree library for training neural networks.
Introduction | Getting started | Functional programming | Examples | Modules | Fine-tuning
Introduction
PAX is a stateful pytree library for training neural networks. The main class of PAX is `pax.Module`.
A `pax.Module` object has two sides:
- It is a normal Python object that can be modified and called.
- It is a pytree object whose leaves are `ndarray`s.
A `pax.Module` object manages the pytree and executes functions that depend on it. Because it is a pytree, it can be passed into and returned from JAX functions running on CPU/GPU/TPU cores.
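To make the "two sides" concrete, here is a minimal sketch in plain JAX of an object that is both an ordinary callable Python object and a pytree. This is not PAX's implementation. The class `ToyLinear` is hypothetical. The `tree_flatten`/`tree_unflatten` pair is the protocol that `jax.tree_util.register_pytree_node_class` expects. PAX modules take care of this registration for you.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyLinear:
    """Hypothetical illustration (not PAX code): a callable object that is also a pytree."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return x @ self.weight + self.bias

    # Pytree protocol: the arrays are the leaves; there is no static aux data.
    def tree_flatten(self):
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


layer = ToyLinear(jnp.ones((3, 2)), jnp.zeros(2))
print(len(jax.tree_util.tree_leaves(layer)))  # 2 leaves: weight and bias
doubled = jax.tree_util.tree_map(lambda a: 2 * a, layer)  # result is still a ToyLinear


@jax.jit
def forward(m, x):
    return m(x)  # the object crosses the jit boundary like any other pytree


print(forward(doubled, jnp.ones((1, 3))))  # every output entry is 3 * 2 = 6
```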
Installation
Install from PyPI:
```sh
pip3 install pax3
```
Or install the latest version from GitHub:
```sh
pip3 install git+https://github.com/ntt123/pax.git
## or test mode to run tests and examples
pip3 install git+https://github.com/ntt123/pax.git#egg=pax3[test]
```
Getting started
```python
import jax
import jax.numpy as jnp
import pax


class Counter(pax.Module):
    bias: jnp.ndarray
    counter: jnp.ndarray

    def __init__(self, start_value: int = 0):
        super().__init__()
        self.register_parameter("bias", jnp.array(0.0))  # trainable
        self.register_state("counter", jnp.array(start_value))  # non-trainable

    def __call__(self, x):
        self.counter = self.counter + 1
        return self.counter * x + self.bias


def loss_fn(model: Counter, x: jnp.ndarray):
    # Run model.__call__ as a pure function that also returns the updated model.
    model, y = pax.module_and_value(model)(x)
    loss = jnp.mean(jnp.square(x - y))
    return loss, (loss, model)


grad_fn = jax.grad(loss_fn, has_aux=True, allow_int=True)

net = Counter(3)
x = jnp.array(10.)
grads, (loss, net) = grad_fn(net, x)
print(grads.counter) # (b'',)
print(grads.bias) # 60.0
```
There are a few noteworthy points in the above example:
- `bias` is registered as a trainable parameter using the `register_parameter` method.
- `counter` is registered as a non-trainable state using the `register_state` method.
- `pax.module_and_value` transforms `model.__call__` into a pure function that returns the updated model in its output.
- `loss_fn` returns the updated `model` in its output.
- `jax.grad` is called with `allow_int=True` so that it can compute gradients with respect to the integer ndarray leaf `counter`.
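For comparison, the next sketch writes the same computation in plain JAX without PAX. Trainable parameters and non-trainable state are threaded through by hand as two separate dicts. The function names and dict keys are illustrative assumptions. The sketch shows the bookkeeping that a `pax.Module` bundles into one object.

```python
import jax
import jax.numpy as jnp


def counter_apply(params, state, x):
    # Pure function: returns the output and the *new* state instead of mutating anything.
    counter = state["counter"] + 1
    y = counter * x + params["bias"]
    return y, {"counter": counter}


def loss_fn(params, state, x):
    y, new_state = counter_apply(params, state, x)
    loss = jnp.mean(jnp.square(x - y))
    return loss, (loss, new_state)


params = {"bias": jnp.array(0.0)}  # differentiated (argnums=0 by default)
state = {"counter": jnp.array(3)}  # passed along, not differentiated
grads, (loss, state) = jax.grad(loss_fn, has_aux=True)(params, state, jnp.array(10.0))
print(grads["bias"])  # 60.0, i.e. -2 * (x - y) = -2 * (10 - 40)
print(state["counter"])  # 4
```

Because only `params` is differentiated here, no `allow_int=True` is needed. The integer counter simply travels alongside in `state`.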
PAX functional programming
pax.pure
Let a "PAX function" be any function whose inputs contain PAX modules.
It is good practice to make PAX functions pure (free of side effects).
Even though PAX modules are stateful objects, modifications to a module's internal state are restricted: only PAX functions decorated with `pax.pure` are allowed to modify PAX modules.
```python
net = Counter(3)
net(0)
# ...
# ValueError: Cannot modify a module in immutable mode.
# Please do this computation inside a function decorated by `pax.pure`.
```
Furthermore, a decorated function can only access a copy of its inputs. Any modification to the copy does not affect the original inputs.
```python
@pax.pure
def update_counter_wo_return(m: Counter):
    m(0)

print(net.counter)
# 3
update_counter_wo_return(net)
print(net.counter) # the same counter
# 3
```
As a consequence, the only way to update an input module is to return it in the output.
```python
@pax.pure
def update_counter(m: Counter):
    m(0)
    return m

print(net.counter)
# 3
net = update_counter(net)
print(net.counter) # increased by 1
# 4
```
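The copy-in, return-to-update semantics described above can be modeled in a few lines of plain Python. This is a toy model, not PAX's implementation. `ToyModule`, `ToyCounter` and `toy_pure` are hypothetical names. `copy.deepcopy` stands in for however PAX actually copies its inputs.

```python
import copy
import functools


class ToyModule:
    """Hypothetical stand-in for a PAX module (not PAX's real class)."""

    _mutable = False  # modules start out immutable

    def __setattr__(self, name, value):
        if not self._mutable:
            raise ValueError("Cannot modify a module in immutable mode.")
        object.__setattr__(self, name, value)


class ToyCounter(ToyModule):
    def __init__(self, start_value: int = 0):
        # Bypass the guard once, during construction.
        object.__setattr__(self, "counter", start_value)

    def __call__(self, x):
        self.counter = self.counter + 1  # raises unless called inside toy_pure
        return self.counter * x


def _set_mutable(obj, flag):
    if isinstance(obj, ToyModule):
        object.__setattr__(obj, "_mutable", flag)


def toy_pure(fn):
    """Sketch of pax.pure-like semantics: fn only ever sees mutable deep copies."""

    @functools.wraps(fn)
    def wrapper(*args):
        copies = copy.deepcopy(args)  # the caller's objects are never touched
        for c in copies:
            _set_mutable(c, True)
        out = fn(*copies)
        for c in copies:
            _set_mutable(c, False)  # a returned module comes back immutable
        return out

    return wrapper


@toy_pure
def update_counter_wo_return(m):
    m(0)


@toy_pure
def update_counter(m):
    m(0)
    return m


net = ToyCounter(3)
try:
    net(0)  # outside toy_pure: rejected
except ValueError as err:
    print(err)  # Cannot modify a module in immutable mode.

update_counter_wo_return(net)
print(net.counter)  # 3: only the copy was modified
net = update_counter(net)
print(net.counter)  # 4: the modified copy was returned
```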
pax.module_and_value
It is good practice to keep functions decorated with `pax.pure` as small as possible.
PAX provides the function `pax.module_and_value`, which transforms a module's method into a pure function. The pure function also returns the updated module in its output. For example:
```python
net = Counter(3)
print(net.counter) # 3
net, y = pax.module_and_value(net)(0)
print(net.counter) # 4
```
In this example, `pax.module_and_value` transforms `net.__call__` into a pure function that returns the updated `net` in its output.
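The core idea of `pax.module_and_value` can be sketched in plain Python. The sketch copies the module, runs the method on the copy, and returns both the copy and the value. `toy_module_and_value`, its `method` parameter, and `PlainCounter` are hypothetical. PAX's real function operates on `pax.Module` objects and may differ in detail.

```python
import copy


class PlainCounter:
    """Plain-Python stand-in for the Counter module (not a PAX class)."""

    def __init__(self, start_value=0):
        self.counter = start_value

    def __call__(self, x):
        self.counter += 1
        return self.counter * x


def toy_module_and_value(module, method="__call__"):
    """Hypothetical sketch: turn module.<method> into a function returning (new_module, value)."""

    def pure_fn(*args, **kwargs):
        new_module = copy.deepcopy(module)  # never mutate the caller's module
        value = getattr(new_module, method)(*args, **kwargs)
        return new_module, value

    return pure_fn


net = PlainCounter(3)
new_net, y = toy_module_and_value(net)(5)
print(net.counter)  # 3 (original untouched)
print(new_net.counter)  # 4
print(y)  # 20
```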
PAX and other libraries
PAX modules have several methods that are similar to PyTorch's:
- `self.register_parameter(name, value)` registers `name` as a trainable parameter.
- `self.apply(func)` applies `func` to all modules of `self` recursively.
- `self.train()` and `self.eval()` return a new module in train/eval mode.
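In PyTorch, `train()`/`eval()` flip a flag in place and return the same module. The methods above instead return a new module. The plain-Python sketch below shows how a recursive `apply` can support that functional style. `ToyTreeModule` is hypothetical. The children-first traversal order is an assumption, not taken from PAX.

```python
import copy
from typing import Callable, List, Sequence


class ToyTreeModule:
    """Hypothetical module with submodules (not PAX code)."""

    def __init__(self, name: str, children: Sequence["ToyTreeModule"] = ()):
        self.name = name
        self.children: List["ToyTreeModule"] = list(children)
        self.training = True

    def apply(self, func: Callable[["ToyTreeModule"], "ToyTreeModule"]) -> "ToyTreeModule":
        """Return a new tree with func applied to every module, children first."""
        new = copy.copy(self)  # shallow copy; children are rebuilt below
        new.children = [child.apply(func) for child in self.children]
        return func(new)

    def eval(self) -> "ToyTreeModule":
        def set_eval(m):
            m.training = False  # m is already a copy, so the original is untouched
            return m

        return self.apply(set_eval)


net = ToyTreeModule("net", [ToyTreeModule("encoder"), ToyTreeModule("decoder")])
eval_net = net.eval()
print(net.training, eval_net.training)  # True False
print([c.training for c in eval_net.children])  # [False, False]
print([c.training for c in net.children])  # [True, True]
```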
PAX learns a lot from other libraries too:
- PAX borrows the idea that a module is also a pytree from treex and equinox.
- PAX uses the concept of trainable parameters and non-trainable states from dm-haiku.
- PAX uses objax's approach to implement optimizers as modules.
- PAX uses the jmp library to support mixed precision.
- And of course, PAX is heavily influenced by JAX's functional programming approach.
Examples
A good way to learn about PAX is to look at the examples in the `examples/` directory.
| Path | Description |
|---|---|
| `char_rnn.py` | train an RNN language model on TPU. |
| `transformer/` | train a Transformer language model on TPU. |
| `mnist.py` | train an image classifier on the MNIST dataset. |
| `notebooks/VAE.ipynb` | train a variational autoencoder. |
| `notebooks/DCGAN.ipynb` | train a DCGAN model on the Celeb-A dataset. |
| `notebooks/fine_tuning_resnet18.ipynb` | fine-tune a pretrained ResNet18 model on the cats vs dogs dataset. |
| `notebooks/mixed_precision.ipynb` | train a U-Net image segmentation model with mixed precision. |
| `mnist_mixed_precision.py` (experimental) | train an image classifier with mixed precision. |
| `wave_gru/` | train a WaveGRU vocoder that converts mel-spectrograms to waveforms. |
| `denoising_diffusion/` | train a denoising diffusion model on the Celeb-A dataset. |
Modules
At the moment, PAX includes:
- `pax.nn.Embed`
- `pax.nn.Linear`
- `pax.nn.{GRU, LSTM}`
- `pax.nn.{BatchNorm1D, BatchNorm2D, LayerNorm, GroupNorm}`
- `pax.nn.{Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose}`
- `pax.nn.{Dropout, Sequential, Identity, Lambda, RngSeq, EMA}`

We intend to add new modules in the near future.
Optimizers
PAX's optimizers are implemented in a separate library, opax. The `opax` library supports many common optimizers, such as `adam`, `adamw`, `sgd` and `rmsprop`. Visit opax's GitHub repository for more information.
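PAX follows objax in treating optimizers as modules. As a result, an optimizer's internal state (here, momentum buffers) is itself a pytree. A pure update step returns that state together with the new parameters. The sketch below illustrates this design with a hypothetical `ToyMomentumSGD` built on `typing.NamedTuple` and `jax.tree_util.tree_map`. It is not opax's API.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class ToyMomentumSGD(NamedTuple):
    """Hypothetical optimizer-as-pytree sketch (not opax's API)."""

    learning_rate: float
    momentum: float
    velocity: Any  # same tree structure as the parameters

    @classmethod
    def create(cls, params, learning_rate=0.1, momentum=0.9):
        zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
        return cls(learning_rate, momentum, zeros)

    def step(self, params, grads):
        """Return (new_params, new_optimizer); nothing is mutated in place."""
        velocity = jax.tree_util.tree_map(
            lambda v, g: self.momentum * v + g, self.velocity, grads
        )
        new_params = jax.tree_util.tree_map(
            lambda p, v: p - self.learning_rate * v, params, velocity
        )
        return new_params, self._replace(velocity=velocity)


params = {"w": jnp.array(1.0)}
grads = {"w": jnp.array(2.0)}  # a fixed pretend gradient, just for the demo
opt = ToyMomentumSGD.create(params, learning_rate=0.1, momentum=0.9)
for _ in range(2):
    params, opt = opt.step(params, grads)
print(params["w"])  # approximately 0.42 (1.0 - 0.1 * 2.0 - 0.1 * 3.8)
```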
Module transformations
A module transformation is a pure function that transforms PAX modules into new PAX modules. A PAX program can be seen as a series of module transformations.
PAX provides several module transformations:
- `pax.select_{parameters,states}`: select parameter/state leaves.
- `pax.update_{parameters,states}`: update a module's parameters/states.
- `pax.enable_{train,eval}_mode`: turn training mode on/off.
- `pax.(un)freeze_parameters`: freeze/unfreeze trainable parameters.
- `pax.apply_mp_policy`: apply a mixed-precision policy.
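Because each transformation maps a module to a new module, a training step can be written as a chain of such calls. The sketch below models three of them with hypothetical plain-Python functions over a frozen dataclass: `toy_select_parameters`, `toy_update_parameters` and `toy_freeze_parameters`. They mirror the descriptions above, but they are not PAX's implementations or signatures.

```python
from dataclasses import dataclass, replace
from typing import Dict, FrozenSet


@dataclass(frozen=True)
class ToyLeafModule:
    """Hypothetical module: named leaves plus the set of trainable leaf names."""

    leaves: Dict[str, float]
    trainable: FrozenSet[str]


def toy_select_parameters(m: ToyLeafModule) -> Dict[str, float]:
    return {k: v for k, v in m.leaves.items() if k in m.trainable}


def toy_update_parameters(m: ToyLeafModule, params: Dict[str, float]) -> ToyLeafModule:
    return replace(m, leaves={**m.leaves, **params})  # new module; m is unchanged


def toy_freeze_parameters(m: ToyLeafModule) -> ToyLeafModule:
    return replace(m, trainable=frozenset())  # every leaf becomes non-trainable state


m = ToyLeafModule(leaves={"w": 1.0, "b": 0.0, "step": 3}, trainable=frozenset({"w", "b"}))
params = toy_select_parameters(m)  # {'w': 1.0, 'b': 0.0}
new_params = {k: v - 0.5 for k, v in params.items()}  # a pretend update step
m = toy_update_parameters(m, new_params)
print(m.leaves)  # {'w': 0.5, 'b': -0.5, 'step': 3}
frozen = toy_freeze_parameters(m)
print(toy_select_parameters(frozen))  # {}
```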
Fine-tuning models
PAX provides the `pax.freeze_parameters` transformation, which converts all trainable parameters of a module into non-trainable states.
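```python
import jax
import pax

net = pax.nn.Sequential(
    pax.nn.Linear(28*28, 64),
    jax.nn.relu,
    pax.nn.Linear(64, 10),
)
# Turn every trainable parameter into a non-trainable state.
net = pax.freeze_parameters(net)
# Replace the last layer with a fresh (trainable) 2-class layer.
net = net.set(-1, pax.nn.Linear(64, 2))
```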
After this, `net.parameters()` will only return the trainable parameters of the last layer.
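Freezing the backbone means that gradients are taken only with respect to the new last layer. Outside PAX, the same effect can be sketched in plain JAX. Pass the backbone and the head as separate pytrees and differentiate only with respect to the head, which is argument 0 and therefore the default for `jax.grad`. The two-layer MLP, parameter names, shapes and the plain SGD update below are illustrative assumptions that mirror the example above.

```python
import jax
import jax.numpy as jnp


def mlp(backbone, head, x):
    h = jax.nn.relu(x @ backbone["w"] + backbone["b"])
    return h @ head["w"] + head["b"]


def loss_fn(head, backbone, x, labels):
    # Softmax cross-entropy; only `head` (argument 0) is differentiated.
    log_probs = jax.nn.log_softmax(mlp(backbone, head, x))
    return -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=1))


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
backbone = {"w": 0.01 * jax.random.normal(k1, (28 * 28, 64)), "b": jnp.zeros(64)}  # "frozen"
head = {"w": 0.01 * jax.random.normal(k2, (64, 2)), "b": jnp.zeros(2)}  # new 2-class head

x = jnp.ones((8, 28 * 28))
labels = jnp.zeros(8, dtype=jnp.int32)
head_grads = jax.grad(loss_fn)(head, backbone, x, labels)
head = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, head, head_grads)  # backbone untouched
```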