JAX Scalify: end-to-end scaled arithmetic
JAX Scalify is a library implementing end-to-end scale propagation and scaled arithmetic, allowing easy training and inference of deep neural networks in low precision (BF16, FP16, FP8).
Loss scaling, tensor scaling and block scaling have been widely used in the deep learning literature to unlock training and inference at lower precision. Most of these works focus on ad-hoc approaches around scaling of matrix multiplications (and sometimes reduction operations). Scalify adopts a more systematic approach with end-to-end scale propagation, i.e. transforming the full computational graph into a `ScaledArray` graph, where every operation takes `ScaledArray` inputs and returns `ScaledArray` outputs:
```python
from dataclasses import dataclass
from jax import Array

@dataclass
class ScaledArray:
    # Main data component, in low precision.
    data: Array
    # Scale, usually scalar, in FP32 or E8M0.
    scale: Array

    def __array__(self) -> Array:
        # Tensor represented by a `ScaledArray`.
        return self.data * self.scale.astype(self.data.dtype)
```
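To make the semantics concrete, here is a minimal NumPy mock-up of the `ScaledArray` idea. This is purely illustrative and not the `jax_scalify` class (which wraps JAX arrays); the `__array__` signature is extended with the standard NumPy-protocol defaults:

```python
import numpy as np

# Minimal NumPy mock-up of the ScaledArray semantics.
# Illustrative only: the real class wraps JAX arrays.
class ScaledArray:
    def __init__(self, data, scale):
        self.data = data    # low-precision payload, e.g. float16
        self.scale = scale  # scalar scale, e.g. float32

    def __array__(self, dtype=None, copy=None):
        # De-scaled tensor: data * scale, computed in the data dtype.
        out = self.data * self.scale.astype(self.data.dtype)
        return out if dtype is None else out.astype(dtype)

# A large-magnitude tensor stored safely in FP16 thanks to an FP32 scale.
sx = ScaledArray(np.array([1.0, 2.0, 4.0], np.float16), np.float32(1024.0))
assert np.asarray(sx).dtype == np.float16
assert np.allclose(np.asarray(sx).astype(np.float32), [1024.0, 2048.0, 4096.0])
```

Note how the payload stays small (1, 2, 4) while the represented values (1024, 2048, 4096) would still fit FP16 here, but could just as well exceed its range: the scale carries the magnitude.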
The main benefits of the `scalify` approach are:
- Agnostic to neural-net model definition;
- Decoupling scaling from low-precision, reducing the computational overhead of dynamic rescaling;
- FP8 matrix multiplications and reductions as simple as a cast;
- Out-of-the-box support of FP16 (scaled) master weights and optimizer state;
- Composable with JAX ecosystem: Flax, Optax, ...
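The "FP8 matrix multiplications as simple as a cast" benefit follows from a simple algebraic fact: scalar scales factor out of a matrix product, so the low-precision payloads can be multiplied directly while the output scale is computed once in FP32. A NumPy sketch of this propagation rule (an illustration of the principle, not the `jax_scalify` implementation):

```python
import numpy as np

# Scale propagation through a matmul: (A*sa) @ (B*sb) == (A @ B) * (sa*sb).
a, sa = np.array([[0.5, 0.25]], np.float16), np.float32(512.0)
b, sb = np.array([[1.0], [2.0]], np.float16), np.float32(256.0)

out_data = a @ b        # low-precision matmul on the payloads only
out_scale = sa * sb     # scalar scale propagation, once, in FP32

# Matches the unscaled FP32 reference computation.
ref = (a.astype(np.float32) * sa) @ (b.astype(np.float32) * sb)
assert np.allclose(out_data.astype(np.float32) * out_scale, ref)
```

The expensive tensor operation thus runs entirely on the low-precision payloads; only a scalar multiply is added per operation.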
Scalify training loop example
A typical JAX training loop only requires a couple of modifications to take advantage of `scalify`. More specifically:

- Represent inputs and state as `ScaledArray`, using the `as_scaled_array` method (or variations of it);
- Propagate scales end-to-end in the `update` training method using the `scalify` decorator;
- (Optionally) add `dynamic_rescale` calls to improve low-precision accuracy and stability.
The following (simplified) example shows how `scalify` can be incorporated into a JAX training loop.
```python
import jax
import jax_scalify as jsa

# Scalify transform on FWD + BWD + optimizer.
# Propagating scale in the computational graph.
@jsa.scalify
def update(state, data, labels):
    # Forward and backward pass on the NN model.
    loss, grads = jax.value_and_grad(model)(state, data, labels)
    # Optimizer applied on scaled state.
    state = optimizer.apply(state, grads)
    return loss, state

# Model + optimizer state.
state = (model.init(...), optimizer.init(...))
# Transform state to scaled array(s).
sc_state = jsa.as_scaled_array(state)

for (data, labels) in dataset:
    # If necessary (e.g. images), scale input data.
    data = jsa.as_scaled_array(data)
    # State update, with full scale propagation.
    sc_state = update(sc_state, data, labels)
    # Optional dynamic rescaling of state.
    sc_state = jsa.ops.dynamic_rescale_l2(sc_state)
```
As presented in the code above, the model state is represented as a JAX PyTree of `ScaledArray`, propagated end-to-end through the model (forward and backward passes) as well as the optimizer.
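To build intuition for the optional rescaling step, the sketch below models what a dynamic rescaling pass in the spirit of `jsa.ops.dynamic_rescale_l2` does on a single tensor: fold the tensor's L2 magnitude into the FP32 scale so the low-precision data sits near unit range, leaving the represented value `data * scale` unchanged. The function body is an illustrative assumption about the semantics, not the library's code:

```python
import numpy as np

def dynamic_rescale_l2(data, scale):
    # Estimate the data magnitude in FP32 (root-mean-square of elements).
    norm = np.float32(np.sqrt(np.mean(np.square(data.astype(np.float32)))))
    # Fold the magnitude into the scale; data is renormalized to unit range.
    new_data = (data.astype(np.float32) / norm).astype(data.dtype)
    return new_data, scale * norm

# A tensor whose payload has drifted far from unit magnitude.
data, scale = np.array([0.01, 0.02, 0.02], np.float16), np.float32(2.0)
new_data, new_scale = dynamic_rescale_l2(data, scale)

# The represented value data * scale is preserved (up to FP16 rounding) ...
assert np.allclose(new_data.astype(np.float32) * new_scale,
                   data.astype(np.float32) * scale, rtol=1e-2)
# ... while the payload is back near unit RMS, ideal for low precision.
assert abs(float(np.sqrt(np.mean(np.square(new_data.astype(np.float32))))) - 1.0) < 1e-2
```

Keeping the payload near unit magnitude is what makes subsequent low-precision operations accurate and numerically stable.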
A full collection of examples is available:

- Scalify quickstart notebook: basics of `ScaledArray` and the `scalify` transform;
- MNIST FP16 training example: adapting the JAX MNIST example to `scalify`;
- MNIST FP8 training example: easy FP8 support in `scalify`;
- MNIST Flax example: `scalify` Flax training, with Optax optimizer integration.
Installation
JAX Scalify can be installed directly from the GitHub repository in a Python virtual environment:

```bash
pip install git+https://github.com/graphcore-research/jax-scalify.git@main
```
Alternatively, for a local development setup:
```bash
git clone git@github.com:graphcore-research/jax-scalify.git
pip install -e ./
```
The major dependencies are the `numpy`, `jax` and `chex` libraries.
Documentation
Development
Running `pre-commit` and `pytest` on the JAX Scalify repository:

```bash
pip install pre-commit
pre-commit run --all-files
pytest -v ./tests
```
The Python wheel can be built with the usual command `python -m build`.
Graphcore IPU support
JAX Scalify v0.1 is compatible with experimental JAX on IPU, which can be installed in a Graphcore Poplar Python environment:

```bash
pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html
```
Here are the common JAX libraries compatible with IPU:
```bash
pip install chex==0.1.6 flax==0.6.4 equinox==0.7.0 jaxtyping==0.2.8
```