Quax
Multiple dispatch in JAX via custom interpreters.
JAX + multiple dispatch + custom array-ish objects, e.g.:
- LoRA weight matrices
- symbolic zeros
- arrays with named dimensions
- structured (e.g. tridiagonal) matrices
- sparse arrays
- quantised arrays
- arrays with physical units attached
- etc! (See the built-in quax.examples library for most of the above!)
For example, this can mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA matmul (W + AB)v into the more efficient Wv + ABv.
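To see why the LoRA rewrite pays off, here is a minimal sketch in plain NumPy (not quax's actual machinery; shapes and names are illustrative). Materialising W + AB costs a full dense n×n update, whereas distributing the matrix-vector product only ever touches the low-rank factors:

```python
import numpy as np

# Illustrative LoRA layer: dense W plus a rank-r update A @ B.
n, r = 500, 4
rng = np.random.default_rng(0)
W = rng.normal(size=(n, n))
A = rng.normal(size=(n, r))
B = rng.normal(size=(r, n))
v = rng.normal(size=(n,))

# Naive: materialise the full n x n matrix W + A @ B, then multiply.
naive = (W + A @ B) @ v
# Rewritten: W @ v + A @ (B @ v) -- never forms the n x n update.
rewritten = W @ v + A @ (B @ v)

assert np.allclose(naive, rewritten)
```

The rewritten form does O(n² + nr) work instead of O(n²r), which is exactly the kind of reinterpretation quax can apply automatically.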
This works via a custom JAX transform. Take an existing JAX program, wrap it in quax.quaxify, and then pass in the custom array-ish objects. This means it works even with existing programs that were not written to accept such array-ish objects!
(Just as jax.vmap takes a program and reinterprets each operation as its batched version, so too will quax.quaxify take a program and reinterpret each operation according to the array-ish types that are passed in.)
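The core idea of "reinterpreting operations by type" can be sketched in plain Python, without any of quax's actual tracing machinery (the class below is a toy, not quax's API). A symbolic zero can stand in for an ordinary number in a program that was never written to expect it:

```python
class SymbolicZero:
    """Toy stand-in for an array of zeros; stores no data at all."""

    def __mul__(self, other):
        return self    # 0 * x == 0, so no multiplication is performed

    def __add__(self, other):
        return other   # 0 + x == x, so the addition is elided

def program(x, y):
    # An existing program, written with ordinary numbers in mind.
    return x * y + y

assert program(2.0, 3.0) == 9.0
# Passing the custom object reinterprets the very same program:
assert program(SymbolicZero(), 3.0) == 3.0
```

quax does this at the level of JAX primitives rather than Python operators, which is what lets it work under jit, grad, and so on.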
Installation
pip install quax
Documentation
Available at https://docs.kidger.site/quax.
Example: LoRA
This example demonstrates everything you need to use the built-in quax.examples.lora library.
Work in progress!
Right now, the following are not supported:
- Control flow primitives (e.g. jax.lax.cond).
- jax.custom_vjp.
It should be fairly straightforward to add support for these; open an issue or pull request if you need them.
See also: other libraries in the JAX ecosystem
Equinox: neural networks.
jaxtyping: type annotations for shape/dtype of arrays.
Optax: first-order gradient (SGD, Adam, ...) optimisers.
Diffrax: numerical differential equation solvers.
Optimistix: root finding, minimisation, fixed points, and least squares.
Lineax: linear solvers.
BlackJAX: probabilistic+Bayesian sampling.
Orbax: checkpointing (async/multi-host/multi-device).
sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
Eqxvision: computer vision models.
Levanter: scalable+reliable training of foundation models (e.g. LLMs).
PySR: symbolic regression. (Non-JAX honourable mention!)
Acknowledgements
Significantly inspired by https://github.com/davisyoshida/qax, https://github.com/stanford-crfm/levanter, and jax.experimental.sparse.