Linear solvers in JAX and Equinox.

Lineax

Lineax is a JAX library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$. (Even when $A$ may be ill-posed or rectangular.)

Features include:

  • PyTree-valued matrices and vectors;
  • General linear operators for Jacobians, transposes, etc.;
  • Efficient linear least squares (e.g. QR solvers);
  • Numerically stable gradients through linear least squares;
  • Support for structured (e.g. symmetric) matrices;
  • Improved compilation times;
  • Improved runtime of some algorithms;
  • Support for both real-valued and complex-valued inputs;
  • All the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support, etc.

Installation

pip install lineax

Requires Python 3.10+, JAX 0.4.38+, and Equinox 0.11.10+.

Documentation

Available at https://docs.kidger.site/lineax.

Quick examples

Lineax can solve a least squares problem with an explicit matrix operator:

import jax.random as jr
import lineax as lx

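# A random overdetermined system: 10 equations, 8 unknowns.
matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 8))
vector = jr.normal(vector_key, (10,))
operator = lx.MatrixLinearOperator(matrix)
# QR accepts rectangular matrices and returns the least-squares solution.
solution = lx.linear_solve(operator, vector, solver=lx.QR())
x = solution.value  # the solution array; `solution` also carries status information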

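Lineax can also solve a problem without ever materializing a matrix. Here, a Newton step on a quadratic function applies the Hessian only through automatic differentiation, so the dense Hessian is never built: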

import jax
import lineax as lx

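key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (10,))

# f(y) = sum((y - 1)^2), minimised at y = 1. The `args` parameter is part of
# JacobianLinearOperator's calling convention fn(x, args); it is unused here.
def quadratic_fn(y, args):
    return jax.numpy.sum((y - 1) ** 2)

gradient_fn = jax.grad(quadratic_fn)
# The Hessian is the Jacobian of the gradient. It is only ever applied to
# vectors (via autodiff), never built as a dense matrix.
hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag)
# CG requires a positive (or negative) semidefinite operator, hence the tag above.
solver = lx.CG(rtol=1e-6, atol=1e-6)
# Newton step: solve H d = grad f(y), then move to y - d.
out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver)
minimum = y - out.value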

Citation

If you found this library to be useful in academic work, then please cite (arXiv:2311.17283):

@article{lineax2023,
    title={Lineax: unified linear solves and linear least-squares in JAX and Equinox},
    author={Jason Rader and Terry Lyons and Patrick Kidger},
    journal={
        AI for science workshop at Neural Information Processing Systems 2023,
        arXiv:2311.17283
    },
    year={2023},
}

(Also consider starring the project on GitHub.)

See also: other libraries in the JAX ecosystem

Always useful
Equinox: neural networks and everything not already in core JAX!
jaxtyping: type annotations for shape/dtype of arrays.

Deep learning
Optax: first-order gradient (SGD, Adam, ...) optimisers.
Orbax: checkpointing (async/multi-host/multi-device).
Levanter: scalable+reliable training of foundation models (e.g. LLMs).
paramax: parameterizations and constraints for PyTrees.

Scientific computing
Diffrax: numerical differential equation solvers.
Optimistix: root finding, minimisation, fixed points, and least squares.
BlackJAX: probabilistic+Bayesian sampling.
sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
PySR: symbolic regression. (Non-JAX honourable mention!)

Awesome JAX
Awesome JAX: a longer list of other JAX projects.

Download files

Source Distribution

lineax-0.1.0.tar.gz (50.2 kB)

Built Distribution

lineax-0.1.0-py3-none-any.whl (74.6 kB)

File details

Details for the file lineax-0.1.0.tar.gz.

File metadata

  • Download URL: lineax-0.1.0.tar.gz
  • Size: 50.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for lineax-0.1.0.tar.gz
Algorithm Hash digest
SHA256 5f1a8f060142af2cdbf7d66b99e8d3071c3aa734b677df6339df4b4c4c0554d2
MD5 1d2177c60193a2005b7bfbe23ca394f2
BLAKE2b-256 35d64e28416a6fe58dd6bc7565b1ffa330f4d0ba7d74212642b1b734c511299e

File details

Details for the file lineax-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: lineax-0.1.0-py3-none-any.whl
  • Size: 74.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for lineax-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 f00911c6b07d427c4835db46856970c8348bc82a035b51f4386ad09382af957a
MD5 cd6fe924ee1052d426f95e3b8d17a19a
BLAKE2b-256 800c2ed47112fc1958a0a81c9b015d4e1861953a1ec3a17b081c0180a25ce82c
