
A stateful pytree library for training neural networks.

Introduction | Getting started | Functional programming | Examples | Modules | Fine-tuning

Introduction

PAX is a JAX-based library for training neural networks.

PAX modules are registered as JAX pytrees, so they can be passed into and returned from JAX transformations such as jax.jit and jax.grad. This makes programming with modules convenient and easy to understand.
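
To make the pytree idea concrete, here is a minimal sketch that uses only JAX. ToyLinear and loss_fn are hypothetical names introduced for illustration; this is not PAX's implementation, only one way a class can be registered as a pytree so that instances flow through jax.jit and jax.grad.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyLinear:
    """Hypothetical minimal module; NOT PAX's implementation."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return self.weight * x + self.bias

    def tree_flatten(self):
        # Children are the array leaves; there is no static auxiliary data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss_fn(model, x, y):
    # Mean squared error between the module's prediction and the target.
    return jnp.mean((model(x) - y) ** 2)


model = ToyLinear(jnp.array(0.0), jnp.array(0.0))
x = jnp.array([1.0, 2.0, 3.0])
y = 2.0 * x + 1.0

# The module is passed straight into jit and grad.
# The gradient comes back with the same structure, i.e. as a ToyLinear.
grads = jax.jit(jax.grad(loss_fn))(model, x, y)
print(type(grads).__name__, grads.weight, grads.bias)
```

Because the gradient has the same tree structure as the module, an update step can be written with jax.tree_util.tree_map over the two objects.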

Installation

Install from PyPI:

pip install pax3

Or install the latest version from GitHub:

pip install git+https://github.com/ntt123/pax.git

# or install with the [test] extra to run tests and examples
pip install git+https://github.com/ntt123/pax.git#egg=pax3[test]

Getting started

Below is a simple example of a Linear module.

import jax.numpy as jnp
import pax

class Linear(pax.Module):
    weight: jnp.ndarray
    bias: jnp.ndarray
    parameters = pax.parameters_method("weight", "bias")

    def __init__(self):
        super().__init__()
        self.weight = jnp.array(0.0)
        self.bias = jnp.array(0.0)

    def __call__(self, x):
        return self.weight * x + self.bias

The implementation is very similar to a normal Python class. However, we need an additional line

    parameters = pax.parameters_method("weight", "bias")

to declare that weight and bias are trainable parameters of the Linear module.

PAX functional programming

pax.pure

A PAX module can have internal states. For example, below is a simple Counter module with an internal counter.

class Counter(pax.Module):
    count: jnp.ndarray

    def __init__(self):
        super().__init__()
        self.count = jnp.array(0)
    
    def __call__(self):
        self.count = self.count + 1
        return self.count

However, PAX aims to guarantee that modules have no side effects as seen from the outside. Modifications of these internal states are therefore restricted. For example, calling Counter directly raises an error.

counter = Counter()
count = counter()
# ...
# ----> 9         self.count = self.count + 1
# ...
# ValueError: Cannot modify a module in immutable mode.
# Please do this computation inside a function decorated by `pax.pure`.

Only functions decorated by pax.pure are allowed to modify the internal states of their input modules.

@pax.pure
def update_counter(counter: Counter):
    count = counter()
    return counter, count

counter, count = update_counter(counter)
print(counter.count, count)
# 1 1

Note that update_counter must return counter; otherwise the caller's counter object is not updated. This is because pax.pure gives update_counter only a copy of the counter object.
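
The copy semantics can be illustrated with a toy, pure-Python sketch. ToyCounter and toy_pure are hypothetical stand-ins written for this example; they only mimic the behaviour described above and are not how pax.pure is implemented.

```python
import copy
import functools


class ToyCounter:
    """Hypothetical stand-in for a module that is immutable by default."""

    def __init__(self):
        # Bypass our own __setattr__ guard while constructing the object.
        object.__setattr__(self, "_mutable", False)
        object.__setattr__(self, "count", 0)

    def __setattr__(self, name, value):
        if not self._mutable:
            raise ValueError("Cannot modify a module in immutable mode.")
        object.__setattr__(self, name, value)

    def __call__(self):
        self.count = self.count + 1
        return self.count


def toy_pure(fn):
    """Run fn on deep copies of its arguments, with mutation enabled on the copies."""

    @functools.wraps(fn)
    def wrapper(*args):
        # Work on copies so the caller's objects are never touched.
        copies = [copy.deepcopy(a) for a in args]
        for c in copies:
            if isinstance(c, ToyCounter):
                object.__setattr__(c, "_mutable", True)
        try:
            return fn(*copies)
        finally:
            # Lock the copies again before they escape to the caller.
            for c in copies:
                if isinstance(c, ToyCounter):
                    object.__setattr__(c, "_mutable", False)

    return wrapper


@toy_pure
def update_counter(counter):
    count = counter()
    return counter, count


counter = ToyCounter()
new_counter, count = update_counter(counter)
print(counter.count, new_counter.count, count)
# 0 1 1
```

The original counter still holds 0 because only the copy was incremented, which is why the updated object has to be returned and rebound by the caller.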

pax.purecall

For convenience, PAX provides the pax.purecall function. It is a shortcut for pax.pure(lambda f, x: [f, f(x)]).

Instead of implementing an update_counter function, we can do the same thing with:

counter, count = pax.purecall(counter)
print(counter.count, count)
# 2 2

Replacing parts

PAX provides utility methods to modify a module in a functional way.

The replace method creates a new module with attributes replaced. For example, to replace weight and bias of a pax.Linear module:

fc = pax.Linear(2, 2)
fc = fc.replace(weight=jnp.ones((2,2)), bias=jnp.zeros((2,)))
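
The same functional-update pattern exists in the Python standard library as dataclasses.replace. The sketch below uses a hypothetical frozen dataclass, ToyLinear, purely as an analogy; it does not involve PAX.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyLinear:
    """Hypothetical immutable record standing in for a module."""

    weight: float
    bias: float


fc = ToyLinear(weight=0.0, bias=0.0)

# replace() builds a new object; the original is left unchanged.
new_fc = replace(fc, weight=1.0)
print(fc, new_fc)
```

As with the replace method described above, the result has to be bound to a name, since the original object is never modified in place.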

The replace_node method replaces a pytree node of a module:

f = pax.Sequential(
    pax.Linear(2, 3),
    pax.Linear(3, 4),
)

f = f.replace_node(f[-1], pax.Linear(3, 5))
print(f.summary())
# Sequential
# ├── Linear(in_dim=2, out_dim=3, with_bias=True)
# └── Linear(in_dim=3, out_dim=5, with_bias=True)
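
For intuition, node replacement on a generic pytree can be sketched with jax.tree_util.tree_map and its is_leaf argument. The nested dict below is a hypothetical stand-in for a module tree, and this is only one possible approach, not necessarily how replace_node works internally.

```python
import jax

# A nested dict/list stands in for a module tree (hypothetical structure).
tree = {"layers": [{"w": 1.0}, {"w": 2.0}]}
old_node = tree["layers"][-1]
new_node = {"w": 3.0, "b": 0.0}

new_tree = jax.tree_util.tree_map(
    # Swap the matching node by identity; keep every other leaf as is.
    lambda node: new_node if node is old_node else node,
    tree,
    # Stop descending at the node to replace, so it is mapped as one unit.
    is_leaf=lambda node: node is old_node,
)
print(new_tree)
```

The node is matched by identity (is) rather than equality, which mirrors passing the node object itself, as in f.replace_node(f[-1], ...).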

PAX and other libraries

PAX learns a lot from other libraries:

  • PAX borrows the idea that a module is also a pytree from treex and equinox.
  • PAX uses the concept of trainable parameters and non-trainable states from dm-haiku.
  • PAX has similar methods to PyTorch such as model.apply(), model.parameters(), model.eval(), etc.
  • PAX uses objax's approach to implement optimizers as modules.
  • PAX uses the jmp library to support mixed precision.
  • And of course, PAX is heavily influenced by JAX's functional programming approach.

Examples

A good way to learn about PAX is to see examples in the examples/ directory.

Path                                  Description
char_rnn.py                           train an RNN language model on TPU.
transformer/                          train a Transformer language model on TPU.
mnist.py                              train an image classifier on the MNIST dataset.
notebooks/VAE.ipynb                   train a variational autoencoder.
notebooks/DCGAN.ipynb                 train a DCGAN model on the Celeb-A dataset.
notebooks/fine_tuning_resnet18.ipynb  fine-tune a pretrained ResNet18 model on the cats vs dogs dataset.
notebooks/mixed_precision.ipynb       train a U-Net image segmentation model with mixed precision.
mnist_mixed_precision.py              train an image classifier with mixed precision.
wave_gru/                             train a WaveGRU vocoder: convert mel-spectrogram to waveform.
denoising_diffusion/                  train a denoising diffusion model on the Celeb-A dataset.

Modules

At the moment, PAX includes:

  • pax.Embed,
  • pax.Linear,
  • pax.{GRU, LSTM},
  • pax.{BatchNorm1D, BatchNorm2D, LayerNorm, GroupNorm},
  • pax.{Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose},
  • pax.{Dropout, Sequential, Identity, Lambda, RngSeq, EMA}.

Optimizers

PAX has its optimizers implemented in a separate library opax. The opax library supports many common optimizers such as adam, adamw, sgd, rmsprop. Visit opax's GitHub repository for more information.

Fine-tuning models

PAX provides the pax.freeze_parameters transformation, which converts all trainable parameters of a module into non-trainable states.

import jax
import pax

net = pax.Sequential(
    pax.Linear(28*28, 64),
    jax.nn.relu,
    pax.Linear(64, 10),
)

# Freeze every parameter, then swap in a fresh (trainable) last layer.
net = pax.freeze_parameters(net)
net = net.set(-1, pax.Linear(64, 2))

After this, net.parameters() returns only the trainable parameters of the last layer.
