Linear solvers in JAX and Equinox.
Lineax
Lineax is a JAX library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$, even when $A$ is ill-posed or rectangular.
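For a rectangular $A \in \mathbb{R}^{m \times n}$, "solve $Ax = b$" needs a precise meaning. The standard one is sketched below as background (this is the textbook definition, not a quote from the Lineax docs; as we understand it, solvers that accept rectangular operators, such as `lx.QR()`, target these solutions). For a tall, full-column-rank $A$ it is the least-squares solution; for a wide, full-row-rank $A$ it is the minimum-norm solution; in both cases it coincides with the Moore-Penrose pseudoinverse applied to $b$:

$$
x^\star = \arg\min_{x} \|Ax - b\|_2 \quad (m \ge n,\ \operatorname{rank} A = n),
\qquad
x^\star = \arg\min_{x \,:\, Ax = b} \|x\|_2 \quad (m \le n,\ \operatorname{rank} A = m),
\qquad
x^\star = A^{+} b .
$$

When $A$ is square and invertible, both cases reduce to $x^\star = A^{-1} b$.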
Features include:
- PyTree-valued matrices and vectors;
- General linear operators for Jacobians, transposes, etc.;
- Efficient linear least squares (e.g. QR solvers);
- Numerically stable gradients through linear least squares;
- Support for structured (e.g. symmetric) matrices;
- Improved compilation times;
- Improved runtime of some algorithms;
- All the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support etc.
Installation
```shell
pip install lineax
```
Requires Python 3.9+, JAX 0.4.13+, and Equinox 0.11.0+.
Documentation
Available at https://docs.kidger.site/lineax.
Quick examples
Lineax can solve a least squares problem with an explicit matrix operator (here a tall 10 × 8 matrix, so the system is overdetermined):
```python
import jax.random as jr
import lineax as lx

matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 8))  # 10 equations, 8 unknowns
vector = jr.normal(vector_key, (10,))
operator = lx.MatrixLinearOperator(matrix)
# QR accepts non-square operators; for a tall matrix it gives the least-squares solution.
solution = lx.linear_solve(operator, vector, solver=lx.QR())
```
Lineax can also solve a problem without ever materializing a matrix. In this quadratic solve, the Hessian is represented only as the Jacobian of the gradient function:
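```python
import jax
import jax.numpy as jnp
import lineax as lx

key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (10,))

def quadratic_fn(y, args):
    return jnp.sum((y - 1) ** 2)

gradient_fn = jax.grad(quadratic_fn)
# The Hessian is the Jacobian of the gradient; it is never built as a dense matrix.
hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-6, atol=1e-6)  # CG needs the operator tagged as positive semidefinite
out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver)
minimum = y - out.value  # one Newton step: y - H^{-1} grad
```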
Citation
If you found this library useful in academic work, please cite (arXiv:2311.17283):
```bibtex
@article{lineax2023,
    title={Lineax: unified linear solves and linear least-squares in JAX and Equinox},
    author={Jason Rader and Terry Lyons and Patrick Kidger},
    journal={
        AI for science workshop at Neural Information Processing Systems 2023,
        arXiv:2311.17283
    },
    year={2023},
}
```
(Also consider starring the project on GitHub.)
Finally
See also: other libraries in the JAX ecosystem
- jaxtyping: type annotations for shape/dtype of arrays.
- Equinox: neural networks.
- Optax: first-order gradient (SGD, Adam, ...) optimisers.
- Diffrax: numerical differential equation solvers.
- Optimistix: root finding, minimisation, fixed points, and least squares.
- BlackJAX: probabilistic+Bayesian sampling.
- Orbax: checkpointing (async/multi-host/multi-device).
- sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
- Eqxvision: computer vision models.
- Levanter: scalable+reliable training of foundation models (e.g. LLMs).
- PySR: symbolic regression. (Non-JAX honourable mention!)
Disclaimer
This is not an official Google product.