Linear solvers in JAX and Equinox.

Lineax

Lineax is a JAX library for linear solves and linear least squares. That is, Lineax provides routines that solve for $x$ in $Ax = b$, even when $A$ is ill-posed or rectangular.

Features include:

  • PyTree-valued matrices and vectors;
  • General linear operators for Jacobians, transposes, etc.;
  • Efficient linear least squares (e.g. QR solvers);
  • Numerically stable gradients through linear least squares;
  • Support for structured (e.g. symmetric) matrices;
  • Improved compilation times;
  • Improved runtime of some algorithms;
  • All the benefits of working with JAX: autodiff, autoparallelism, GPU/TPU support etc.

Installation

pip install lineax

Requires Python 3.9+, JAX 0.4.13+, and Equinox 0.11.0+.

Documentation

Available at https://docs.kidger.site/lineax.

Quick examples

Lineax can solve a least squares problem with an explicit matrix operator:

import jax.random as jr
import lineax as lx

matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 8))
vector = jr.normal(vector_key, (10,))
operator = lx.MatrixLinearOperator(matrix)
solution = lx.linear_solve(operator, vector, solver=lx.QR())

Alternatively, Lineax can solve a problem without ever materializing a matrix, as in this quadratic solve:

import jax
import jax.numpy as jnp
import lineax as lx

key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (10,))

def quadratic_fn(y, args):
    return jnp.sum((y - 1) ** 2)

gradient_fn = jax.grad(quadratic_fn)
hessian = lx.JacobianLinearOperator(gradient_fn, y, tags=lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-6, atol=1e-6)
out = lx.linear_solve(hessian, gradient_fn(y, args=None), solver)
minimum = y - out.value

Finally

See also: other libraries in the JAX ecosystem

Equinox: neural networks.

Optax: first-order gradient (SGD, Adam, ...) optimisers.

Diffrax: numerical differential equation solvers.

jaxtyping: type annotations for shape/dtype of arrays.

Eqxvision: computer vision models.

sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.

Levanter: scalable+reliable training of foundation models (e.g. LLMs).

Disclaimer

This is not an official Google product.
