
Multiple dispatch in JAX via custom interpreters.


Quax

JAX + multiple dispatch + custom array-ish objects

For example, this can mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul (W + AB)v into the more efficient Wv + ABv.
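To make the LoRA rewrite concrete, the pure-JAX sketch below checks that both forms give the same result and counts the multiply-adds each one needs. It makes some assumptions that do not come from quax. The shapes are chosen to match the 12×10 linear layer in the LoRA example further down. The rank is 2. The factor names `A` and `B` are illustrative only, not quax internals. The saving comes from evaluating the low-rank term as A(Bv), so the full m×n product AB is never formed.

```python
import jax.numpy as jnp
import jax.random as jr

m, n, r = 12, 10, 2  # output dim, input dim, LoRA rank (illustrative values)
kW, kA, kB, kv = jr.split(jr.PRNGKey(0), 4)
W = jr.normal(kW, (m, n))
A = jr.normal(kA, (m, r))
B = jr.normal(kB, (r, n))
v = jr.normal(kv, (n,))

# Naive: materialise the full m x n matrix W + AB, then do one mat-vec.
naive = (W + A @ B) @ v

# Rewritten: never form AB; the low-rank term is computed as A @ (B @ v).
rewritten = W @ v + A @ (B @ v)

# Approximate multiply-add counts (ignoring the elementwise sums).
naive_cost = m * r * n + m * n          # forming AB, then (W + AB) @ v
rewritten_cost = m * n + r * n + m * r  # W @ v, B @ v, then A @ (B @ v)
print(naive_cost, rewritten_cost)       # 360 164
```

The gap grows with the layer size. The naive form scales like m·n·r, while the rewritten form scales like m·n + r·(m + n).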

Applications include:

  • LoRA weight matrices
  • symbolic zeros
  • arrays with named dimensions
  • structured (e.g. tridiagonal) matrices
  • sparse arrays
  • quantised arrays
  • arrays with physical units attached
  • etc! (See the built-in quax.examples library for most of the above!)
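The core idea behind several of these applications is to choose an implementation based on the types of *all* the arguments. The toy below shows this for symbolic zeros using plain Python and JAX. It is not quax's API: `Zero`, `register` and `dispatch` are hypothetical names invented for this sketch. A symbolic zero short-circuits multiplication and addition. Any combination without a registered rule falls back to materialising and calling the ordinary JAX operation.

```python
import jax.numpy as jnp


class Zero:
    """Hypothetical symbolic array of zeros: stores a shape, never any data."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    def materialise(self):
        return jnp.zeros(self.shape)


def _shape(a):
    return a.shape if isinstance(a, Zero) else jnp.shape(a)


_RULES = []  # entries: (op name, lhs type, rhs type, rule), checked in order


def register(op, lhs_type, rhs_type):
    def decorator(fn):
        _RULES.append((op, lhs_type, rhs_type, fn))
        return fn
    return decorator


def dispatch(op, x, y):
    # The first registered rule whose types match both arguments wins.
    # (A real multiple-dispatch system would pick the *most specific* rule.)
    for name, tx, ty, rule in _RULES:
        if name == op and isinstance(x, tx) and isinstance(y, ty):
            return rule(x, y)
    # Fallback: materialise symbolic operands and run the dense JAX op.
    x = x.materialise() if isinstance(x, Zero) else x
    y = y.materialise() if isinstance(y, Zero) else y
    return {"mul": jnp.multiply, "add": jnp.add}[op](x, y)


@register("mul", Zero, object)
@register("mul", object, Zero)
def _mul_zero(x, y):
    # 0 * anything = 0: no arithmetic, only the broadcast shape is computed.
    return Zero(jnp.broadcast_shapes(_shape(x), _shape(y)))


@register("add", Zero, Zero)
def _add_zero_zero(x, y):
    return Zero(jnp.broadcast_shapes(x.shape, y.shape))


@register("add", Zero, object)
def _add_zero_left(x, y):
    # 0 + y = y, broadcast to the combined shape.
    return jnp.broadcast_to(y, jnp.broadcast_shapes(x.shape, _shape(y)))


@register("add", object, Zero)
def _add_zero_right(x, y):
    return jnp.broadcast_to(x, jnp.broadcast_shapes(_shape(x), y.shape))


x = jnp.arange(3.0)
z = Zero((3,))
print(type(dispatch("mul", z, x)).__name__)  # Zero: the (Zero, object) rule fired
print(dispatch("add", z, x))                 # the (Zero, object) add rule returns x
print(dispatch("add", x, x))                 # no rule matches: dense jnp.add
```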

This works via a custom JAX transform. Take an existing JAX program, wrap it in quax.quaxify, and then pass in the custom array-ish objects. This means it works even with existing programs that were not written to accept such array-ish objects!

(Just as jax.vmap takes a program and reinterprets each operation as its batched version, so too does quax.quaxify take a program and reinterpret each operation according to the array-ish types that are passed in.)
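The "reinterpret each operation" idea can be sketched in plain JAX. `jax.make_jaxpr` exposes a program as a list of primitive applications (a jaxpr), and a custom interpreter walks that list and decides how to evaluate each one. The toy interpreter below is loosely modelled on JAX's own tutorial on custom jaxpr interpreters. It swaps `sin` for `cos` without touching `program`. The `interpret` function and its `overrides` table are hypothetical, not quax internals. Quax chooses rules based on the types of the values flowing through the program, and it also handles higher-order primitives such as control flow, which this sketch ignores.

```python
import jax
import jax.numpy as jnp
from jax import lax


def interpret(closed_jaxpr, overrides, *args):
    """Evaluate a jaxpr equation by equation, substituting overridden rules.

    `overrides` maps a primitive to a replacement function; any primitive
    not in the map is evaluated normally via `primitive.bind`.
    """
    jaxpr = closed_jaxpr.jaxpr
    env = {}

    def read(var):
        # Literals (inline constants) carry their value directly; variables
        # are looked up in the environment.
        return var.val if hasattr(var, "val") else env[var]

    for var, val in zip(jaxpr.constvars, closed_jaxpr.consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        rule = overrides.get(eqn.primitive, eqn.primitive.bind)
        outvals = rule(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val
    return [read(v) for v in jaxpr.outvars]


def program(x):
    return lax.mul(lax.sin(x), x)


x = jnp.arange(3.0)
closed = jax.make_jaxpr(program)(x)
(normal,) = interpret(closed, {}, x)
# Reinterpret `sin` as `cos`; primitive params (if any) are ignored here.
(swapped,) = interpret(closed, {lax.sin_p: lambda v, **params: lax.cos(v)}, x)
print(normal, swapped)
```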

Installation

pip install quax

Documentation

Available at https://nstarman.github.io/quax.

Example: LoRA

This example demonstrates everything you need to use the built-in quax.examples.lora library.

import equinox as eqx
import jax.random as jr
import quax
import quax.examples.lora as lora

#
# Start off with any JAX program: here, the forward pass through a linear layer.
#

key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
linear = eqx.nn.Linear(10, 12, key=key1)
vector = jr.normal(key2, (10,))

def run(model, x):
  return model(x)

run(linear, vector)  # can call this as normal

#
# Now let's Lora-ify it.
#

# Step 1: make the weight be a LoraArray.
lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)
# Step 2: quaxify and call the original function. The transform will call the
# original function, whilst looking up any multiple dispatch rules registered.
# (In this case for doing matmuls against LoraArrays.)
quax.quaxify(run)(lora_linear, vector)
# Appendix: Quax includes a helper to automatically apply Step 1 to all
# `eqx.nn.Linear` layers in a model.
lora_linear = lora.loraify(linear, rank=2, key=key3)

Work in progress!

Right now, the following are not supported:

  • jax.custom_vjp

It should be fairly straightforward to add support for this; open an issue or pull request. (We've already got jax.custom_jvp, jax.lax.cond_p, jax.lax.while_p, and jax.lax.scan_p. :) )
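Control-flow primitives need dedicated support because they are higher-order: each one carries sub-programs as parameters. A transform that changes how operations are interpreted therefore has to descend into those bodies, not just bind the primitive on its inputs. The plain-JAX sketch below involves no quax. It shows which parameters hold the nested jaxprs for `cond`, `while` and `scan`. `program` is an illustrative function invented for this sketch.

```python
import jax
import jax.numpy as jnp
from jax import lax


def program(pred, x, xs):
    # lax.cond: both branches are traced into sub-jaxprs.
    y = lax.cond(pred, jnp.sin, jnp.cos, x)
    # lax.while_loop: the predicate and the body become sub-jaxprs.
    y = lax.while_loop(lambda v: v < 10.0, lambda v: v + 1.0, y)
    # lax.scan: the step function becomes a sub-jaxpr.
    total, _ = lax.scan(lambda carry, v: (carry + v, carry), y, xs)
    return total


closed = jax.make_jaxpr(program)(True, jnp.float32(0.5), jnp.arange(3.0))

# For each top-level primitive, list the parameters that hold nested programs.
nested = {}
for eqn in closed.jaxpr.eqns:
    keys = [k for k in ("branches", "cond_jaxpr", "body_jaxpr", "jaxpr")
            if k in eqn.params]
    if keys:
        nested[eqn.primitive.name] = keys
print(nested)
```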

See also: other libraries in the JAX ecosystem

Always useful
Equinox: neural networks and everything not already in core JAX!
jaxtyping: type annotations for shape/dtype of arrays.

Deep learning
Optax: first-order gradient (SGD, Adam, ...) optimisers.
Orbax: checkpointing (async/multi-host/multi-device).
Levanter: scalable+reliable training of foundation models (e.g. LLMs).

Scientific computing
Diffrax: numerical differential equation solvers.
Optimistix: root finding, minimisation, fixed points, and least squares.
Lineax: linear solvers.
BlackJAX: probabilistic+Bayesian sampling.
sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
PySR: symbolic regression. (Non-JAX honourable mention!)

Built on Quax
Quaxed: a namespace of already-wrapped quaxify(jnp.foo) operations.
unxt: Unitful Quantities.

Awesome JAX
Awesome JAX: a longer list of other JAX projects.

Acknowledgements

Significantly inspired by https://github.com/davisyoshida/qax, https://github.com/stanford-crfm/levanter, and jax.experimental.sparse.


Download files


Source Distribution

quax-0.3.0.tar.gz (168.8 kB)

Built Distribution


quax-0.3.0-py3-none-any.whl (39.2 kB)

File details

Details for the file quax-0.3.0.tar.gz.

File metadata

  • Download URL: quax-0.3.0.tar.gz
  • Size: 168.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for quax-0.3.0.tar.gz:

  • SHA256: 269255967890bd169973a9f1fed34cdd69f5997df104d0b1edf1ec0af4bf881b
  • MD5: b168841db4a6256cc34a34ceb598d93c
  • BLAKE2b-256: 6641a4f6486c2f55a3018745c9c59c27ae5e0a05e226fdf75adaa9820da8028f
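To check a downloaded archive against the SHA256 digest listed above, Python's standard `hashlib` is enough. The sketch below makes some assumptions: the helper names `sha256_of` and `verify` are hypothetical, and the archive is assumed to sit in the current directory. pip can also enforce this check itself. To do that, pin `--hash=sha256:...` in a requirements file and install with `--require-hashes`.

```python
import hashlib
from pathlib import Path

EXPECTED_SHA256 = "269255967890bd169973a9f1fed34cdd69f5997df104d0b1edf1ec0af4bf881b"


def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in fixed-size chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    # Hex digests are compared case-insensitively.
    return sha256_of(path) == expected.lower()


if __name__ == "__main__":
    archive = Path("quax-0.3.0.tar.gz")  # assumed download location
    if archive.exists():
        print("OK" if verify(archive) else "HASH MISMATCH")
```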


Provenance

The following attestation bundles were made for quax-0.3.0.tar.gz:

Publisher: cd.yml on nstarman/quax


File details

Details for the file quax-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: quax-0.3.0-py3-none-any.whl
  • Size: 39.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for quax-0.3.0-py3-none-any.whl:

  • SHA256: 76827c1d81e1b2d311f05f7a4d1e5aee15ae3c519bea046f0f5d9586104d114e
  • MD5: 87e660118af87a5546f8ba9301312c24
  • BLAKE2b-256: e9dca70508165b6bd6d8946cb849339de9024d09b849b3aa754b1f939935db47


Provenance

The following attestation bundles were made for quax-0.3.0-py3-none-any.whl:

Publisher: cd.yml on nstarman/quax

