Multiple dispatch in JAX via custom interpreters.

Quax

JAX + multiple dispatch + custom array-ish objects, e.g.:

  • LoRA weight matrices
  • symbolic zeros
  • arrays with named dimensions
  • structured (e.g. tridiagonal) matrices
  • sparse arrays
  • quantised arrays
  • arrays with physical units attached
  • etc! (See the built-in quax.examples library for most of the above!)

For example, this can mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul (W + AB)v into the more-efficient Wv + ABv.
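To make the LoRA rewrite concrete, here is a minimal sketch in plain `jax.numpy` (no Quax involved). The shapes, the rank, and the function names `matvec_materialised` and `matvec_factored` are assumptions chosen for illustration; the sketch only shows that the two expressions agree numerically while the second never forms the full product AB.

```python
import jax.numpy as jnp
import jax.random as jr

# Illustrative sizes only: a dense weight W plus a rank-4 update A @ B.
out_size, in_size, rank = 64, 64, 4
k_w, k_a, k_b, k_v = jr.split(jr.PRNGKey(0), 4)
W = jr.normal(k_w, (out_size, in_size))
A = jr.normal(k_a, (out_size, rank))
B = jr.normal(k_b, (rank, in_size))
v = jr.normal(k_v, (in_size,))


def matvec_materialised(W, A, B, v):
    # (W + AB)v: builds the full out_size x in_size update before multiplying.
    return (W + A @ B) @ v


def matvec_factored(W, A, B, v):
    # Wv + A(Bv): one dense matvec plus two skinny ones; AB is never formed.
    return W @ v + A @ (B @ v)


y_materialised = matvec_materialised(W, A, B, v)
y_factored = matvec_factored(W, A, B, v)
```

The two results should match up to floating-point rounding. The factored form costs roughly `rank * (out_size + in_size)` extra multiply-adds on top of `W @ v`, rather than the `rank * out_size * in_size` needed to form `A @ B` first.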

This works via a custom JAX transform. Take an existing JAX program, wrap it in quax.quaxify, and then pass in the custom array-ish objects. This means it will work even with existing programs that were not written to accept such array-ish objects!

(Just like how jax.vmap takes a program, but reinterprets each operation as its batched version, so too will quax.quaxify take a program and reinterpret each operation according to what array-ish types are passed.)
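The `jax.vmap` half of that analogy can be sketched as follows. The function `affine` and its inputs are hypothetical and exist only for illustration. The function is written for a single vector, and `jax.vmap` reinterprets its operations so that it runs over a batch without being rewritten.

```python
import jax
import jax.numpy as jnp


def affine(w, x):
    # Written for a single vector x of shape (3,); knows nothing about batches.
    return jnp.dot(w, x) + 1.0


w = jnp.arange(3.0)    # [0., 1., 2.]
xs = jnp.ones((5, 3))  # a batch of five vectors

# jax.vmap reinterprets each operation inside `affine` as its batched version,
# mapping over axis 0 of `xs` while holding `w` fixed.
batched_affine = jax.vmap(affine, in_axes=(None, 0))
ys = batched_affine(w, xs)  # shape (5,)
```

Per the description above, `quax.quaxify` is meant to be applied in the same wrapper style, with the reinterpretation driven by the array-ish types passed in rather than by a batch axis.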

Installation

pip install quax

Documentation

Available at https://docs.kidger.site/quax.

Example: LoRA

This example demonstrates everything you need to use the built-in quax.examples.lora library.

--8<-- ".lora-example.md"

Work in progress!

Right now, the following are not supported:

  • Control flow primitives (e.g. jax.lax.cond).
  • jax.custom_vjp.

It should be fairly straightforward to add support for these; open an issue or pull request.

See also: other libraries in the JAX ecosystem

Equinox: neural networks.

jaxtyping: type annotations for shape/dtype of arrays.

Optax: first-order gradient (SGD, Adam, ...) optimisers.

Diffrax: numerical differential equation solvers.

Optimistix: root finding, minimisation, fixed points, and least squares.

Lineax: linear solvers.

BlackJAX: probabilistic+Bayesian sampling.

Orbax: checkpointing (async/multi-host/multi-device).

sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.

Eqxvision: computer vision models.

Levanter: scalable+reliable training of foundation models (e.g. LLMs).

PySR: symbolic regression. (Non-JAX honourable mention!)

Acknowledgements

Significantly inspired by https://github.com/davisyoshida/qax, https://github.com/stanford-crfm/levanter, and jax.experimental.sparse.
