
Linear solvers in JAX and Equinox.


Lineax

Lineax is a JAX library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$, even when $A$ is ill-posed or rectangular.

Features include:

  • PyTree-valued matrices and vectors;
  • General linear operators for Jacobians, transposes, etc.;
  • Efficient linear least squares (e.g. QR solvers);
  • Numerically stable gradients through linear least squares;
  • Support for structured (e.g. symmetric) matrices;
  • Improved compilation times;
  • Improved runtime of some algorithms;
  • All the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support etc.

Installation

```bash
pip install lineax
```

Requires Python 3.9+, JAX 0.4.13+, and Equinox 0.11.0+.

Documentation

Available at https://docs.kidger.site/lineax.

Quick examples

Lineax can solve a least squares problem with an explicit matrix operator:

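```python
import jax.random as jr
import lineax as lx

matrix_key, vector_key = jr.split(jr.PRNGKey(0))
# A 10x8 overdetermined system: more equations than unknowns.
matrix = jr.normal(matrix_key, (10, 8))
vector = jr.normal(vector_key, (10,))
operator = lx.MatrixLinearOperator(matrix)
# QR returns the least-squares solution of the rectangular system.
solution = lx.linear_solve(operator, vector, solver=lx.QR())
```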

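Alternatively, Lineax can solve a problem without ever materializing a matrix. This quadratic solve takes one Newton step, with the Hessian represented only through Jacobian-vector products of the gradient: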

```python
import jax
import lineax as lx

key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (10,))

def quadratic_fn(y, args):
    return jax.numpy.sum((y - 1)**2)

gradient_fn = jax.grad(quadratic_fn)
# The Hessian as a linear operator; the positive semidefinite tag lets CG be used.
hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-6, atol=1e-6)
# Solve H @ step = grad f(y); y - step is the minimum (all ones here).
out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver)
minimum = y - out.value
```

Finally

See also: other libraries in the JAX ecosystem

jaxtyping: type annotations for shape/dtype of arrays.

Equinox: neural networks.

Optax: first-order gradient (SGD, Adam, ...) optimisers.

Diffrax: numerical differential equation solvers.

Optimistix: root finding, minimisation, fixed points, and least squares.

BlackJAX: probabilistic+Bayesian sampling.

Orbax: checkpointing (async/multi-host/multi-device).

sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.

Eqxvision: computer vision models.

Levanter: scalable+reliable training of foundation models (e.g. LLMs).

PySR: symbolic regression. (Non-JAX honourable mention!)

Disclaimer

This is not an official Google product.


Download files


Source Distribution

lineax-0.0.3.tar.gz (41.1 kB)

Built Distribution

lineax-0.0.3-py3-none-any.whl (62.7 kB)

File details

Details for the file lineax-0.0.3.tar.gz.

File metadata

  • Download URL: lineax-0.0.3.tar.gz
  • Size: 41.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.6

File hashes

Hashes for lineax-0.0.3.tar.gz
  • SHA256: 58e3ff9cd82d92a606f46088d80d5c45258d33a7f8eaef717251829230523a0f
  • MD5: 756f8fe3ec52a0ef4af6c9287a6a3fca
  • BLAKE2b-256: 6fbb686599947ddd873fc263680bcf7a5f44028a4a918ae551df71a6e6194e51


File details

Details for the file lineax-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: lineax-0.0.3-py3-none-any.whl
  • Size: 62.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.6

File hashes

Hashes for lineax-0.0.3-py3-none-any.whl
  • SHA256: 94c0975abe6cb9a98d6e3218c8d631f2e94814e2faa139ccadf3bf3cf25bb9c5
  • MD5: beb14bf03c754ab546e41f0c3c8c44e7
  • BLAKE2b-256: 985d032890a139fb8c3c9d3d316401407e0d831b0392037270bb6444dfed7642

