Linear solvers in JAX and Equinox.

Lineax

Lineax is a JAX library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$, including when $A$ is ill-posed or rectangular.

Features include:

  • PyTree-valued matrices and vectors;
  • General linear operators for Jacobians, transposes, etc.;
  • Efficient linear least squares (e.g. QR solvers);
  • Numerically stable gradients through linear least squares;
  • Support for structured (e.g. symmetric) matrices;
  • Improved compilation times;
  • Improved runtime of some algorithms;
  • Support for both real-valued and complex-valued inputs;
  • All the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support, etc.

Installation

pip install lineax

Requires Python 3.9+, JAX 0.4.13+, and Equinox 0.11.0+.

Documentation

Available at https://docs.kidger.site/lineax.

Quick examples

Lineax can solve a least squares problem with an explicit matrix operator:

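import jax.random as jr
import lineax as lx

# A tall 10x8 system: more equations than unknowns, so it is solved in the least-squares sense.
matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 8))
vector = jr.normal(vector_key, (10,))
operator = lx.MatrixLinearOperator(matrix)
# QR handles non-square operators; linear_solve returns a Solution whose `.value` is x.
solution = lx.linear_solve(operator, vector, solver=lx.QR())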

or Lineax can solve a problem without ever materializing a matrix, as in this Newton step for a quadratic function, where the Hessian is never formed explicitly:

import jax
import lineax as lx

key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (10,))

def quadratic_fn(y, args):
    return jax.numpy.sum((y - 1)**2)

gradient_fn = jax.grad(quadratic_fn)
# The Hessian, represented matrix-free as the Jacobian of the gradient.
hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-6, atol=1e-6)
out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver)
# Newton step: y - H^{-1} grad f(y).
minimum = y - out.value

Citation

If you found this library to be useful in academic work, then please cite (arXiv:2311.17283):

@article{lineax2023,
    title={Lineax: unified linear solves and linear least-squares in JAX and Equinox},
    author={Jason Rader and Terry Lyons and Patrick Kidger},
    journal={
        AI for science workshop at Neural Information Processing Systems 2023,
        arXiv:2311.17283
    },
    year={2023},
}

(Also consider starring the project on GitHub.)

Finally

See also: other libraries in the JAX ecosystem

jaxtyping: type annotations for shape/dtype of arrays.

Equinox: neural networks.

Optax: first-order gradient (SGD, Adam, ...) optimisers.

Diffrax: numerical differential equation solvers.

Optimistix: root finding, minimisation, fixed points, and least squares.

BlackJAX: probabilistic+Bayesian sampling.

Orbax: checkpointing (async/multi-host/multi-device).

sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.

Eqxvision: computer vision models.

Levanter: scalable+reliable training of foundation models (e.g. LLMs).

PySR: symbolic regression. (Non-JAX honourable mention!)

Disclaimer

This is not an official Google product.


Download files

Source Distribution

lineax-0.0.5.tar.gz (44.0 kB)

Uploaded Source

Built Distribution

lineax-0.0.5-py3-none-any.whl (66.5 kB)

Uploaded Python 3

File details

Details for the file lineax-0.0.5.tar.gz.

File metadata

  • Download URL: lineax-0.0.5.tar.gz
  • Upload date:
  • Size: 44.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.3

File hashes

Hashes for lineax-0.0.5.tar.gz
Algorithm Hash digest
SHA256 ed41098dbd94b639287f15682e19cd9ec1dc9ecad6844ee861536044a0e3285c
MD5 a846fe3da49b5c5f9c76abc8b4b1d892
BLAKE2b-256 7aaf5cd12d7613e4a5d9c4a84acc0ea3392c4467644324796776334a4bbae4b4

File details

Details for the file lineax-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: lineax-0.0.5-py3-none-any.whl
  • Upload date:
  • Size: 66.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.3

File hashes

Hashes for lineax-0.0.5-py3-none-any.whl
Algorithm Hash digest
SHA256 601bcd121b32707642a6bae65188d5925ff6f2c3c46a99954592a6e8be82ac63
MD5 fad478b9c164586c3b0314786bd45c75
BLAKE2b-256 c05700e0a61a1c4aada93377ef5b7a39965429c04c7283135b8611c1de594247
