JAX Algebraic Multigrid Solvers in Python

AMJax

AMJax bridges PyAMG and JAX for algebraic multigrid (AMG) solvers: it converts PyAMG-constructed hierarchies into jax.{jit,grad,vmap}-compatible, multi-level solvers and preconditioners for large sparse linear systems.
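
For readers new to AMG, the idea behind one multi-level cycle can be sketched on two levels: smooth, restrict the residual, solve on the coarse level, prolong the correction, smooth again. The sketch below is illustrative only and is not AMJax's implementation. The dense 1D Poisson matrix, the helper names (`poisson_1d`, `interpolation`, `jacobi`, `two_grid_cycle`) and the damping weight 2/3 are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

def poisson_1d(n):
    # Dense 1D Poisson matrix (2 on the diagonal, -1 next to it); illustration only.
    return 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)

def interpolation(n_coarse):
    # Linear interpolation from n_coarse points to 2 * n_coarse + 1 fine points.
    # Coarse point j sits at fine index 2 * j + 1.
    n_fine = 2 * n_coarse + 1
    j = jnp.arange(n_coarse)
    P = jnp.zeros((n_fine, n_coarse))
    P = P.at[2 * j, j].add(0.5)
    P = P.at[2 * j + 1, j].add(1.0)
    P = P.at[2 * j + 2, j].add(0.5)
    return P

def jacobi(A, x, b, omega=2.0 / 3.0, sweeps=2):
    # Weighted Jacobi smoother: damps the oscillatory part of the error.
    d = jnp.diag(A)
    for _ in range(sweeps):
        x = x + omega * (b - A @ x) / d
    return x

def two_grid_cycle(A, P, x, b):
    x = jacobi(A, x, b)                 # pre-smoothing
    r_c = P.T @ (b - A @ x)             # restrict the residual to the coarse level
    A_c = P.T @ A @ P                   # Galerkin coarse operator
    e_c = jnp.linalg.solve(A_c, r_c)    # coarse solve (dense, for simplicity)
    x = x + P @ e_c                     # prolong the correction
    return jacobi(A, x, b)              # post-smoothing

n_coarse = 31
A = poisson_1d(2 * n_coarse + 1)
P = interpolation(n_coarse)
b = jnp.ones(A.shape[0])

cycle = jax.jit(lambda x: two_grid_cycle(A, P, x, b))

x = jnp.zeros_like(b)
residuals = []
for _ in range(5):
    x = cycle(x)
    residuals.append(float(jnp.linalg.norm(b - A @ x)))
```

A real AMG hierarchy applies the same pattern recursively across several levels and builds the interpolation operators from the matrix entries rather than from a known grid.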

Installation

uv add amjax

Usage

Direct solve

import pyamg
import jax
import jax.numpy as jnp
from amjax import AMJAXSolver

A  = pyamg.gallery.poisson((100, 100), format="csr")
b  = jnp.ones(A.shape[0])

ml = AMJAXSolver.from_pyamg(pyamg.ruge_stuben_solver(A))

solve = jax.jit(lambda b: ml.solve(b, tol=1e-10, maxiter=100))
x = solve(b)
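
A note on precision: JAX uses 32-bit floats by default, and a tolerance of 1e-10 is below what float32 can resolve. The benchmark settings in this README mention f64, so enabling 64-bit mode is likely what is intended here; whether AMJax sets this flag itself is not stated, so treat the snippet below as an assumption rather than a requirement.

```python
import jax

# Best set once at program start, before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

b = jnp.ones(8)
eps32 = float(jnp.finfo(jnp.float32).eps)  # float32 machine epsilon, far above 1e-10
print(b.dtype)
```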

Preconditioning

AMJAXSolver exposes a preconditioner compatible with any JAX Krylov solver:

import jax.scipy.sparse.linalg  # explicit import; `import jax` alone may not load this submodule
from jax.experimental import sparse as jsparse

A_jax = jsparse.BCOO.from_scipy_sparse(A)
M = ml.aspreconditioner(cycle='V')

x, info = jax.scipy.sparse.linalg.cg(A_jax, b, M=M, tol=1e-10, maxiter=30)
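
To make the role of `M` concrete: `jax.scipy.sparse.linalg.cg` accepts a matrix or a callable for `M`, and it should approximate the action of the inverse of `A` on a vector. The sketch below uses a hypothetical Jacobi (diagonal) preconditioner on a small dense SPD matrix as a stand-in for `ml.aspreconditioner(...)`. It does not use AMJax, and the names `diag` and `jacobi_preconditioner` are assumptions for this example.

```python
import jax
import jax.numpy as jnp
import jax.scipy.sparse.linalg

# Small SPD stand-in system with a strongly varying diagonal (not an AMG hierarchy).
n = 50
diag = jnp.linspace(2.0, 200.0, n)
A = jnp.diag(diag) - 0.5 * jnp.eye(n, k=1) - 0.5 * jnp.eye(n, k=-1)
b = jnp.ones(n)

def jacobi_preconditioner(r):
    # Approximates A^{-1} r using only the diagonal of A.
    return r / diag

x, info = jax.scipy.sparse.linalg.cg(
    A, b, M=jacobi_preconditioner, tol=1e-5, maxiter=200
)

# Check the relative residual explicitly instead of relying on `info`.
rel_residual = float(jnp.linalg.norm(b - A @ x) / jnp.linalg.norm(b))
```

An AMG preconditioner plays the same role as `jacobi_preconditioner` here, but applies one multigrid cycle per call instead of a diagonal scaling.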

Batched solve with jax.vmap

import numpy as np

B = jnp.array(np.random.rand(4, A.shape[0]))  # (n_rhs, n)
solve_batch = jax.jit(jax.vmap(lambda b: ml.solve(b, tol=1e-8, maxiter=100)))
X = solve_batch(B)
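
By default `jax.vmap` maps over the leading axis, so `B` has shape (n_rhs, n) and row k of `X` solves the system for row k of `B`. The sketch below checks this on a small dense system, with `jnp.linalg.solve` standing in for `ml.solve`; the matrix and the name `solve_one` are assumptions for this example.

```python
import jax
import jax.numpy as jnp

n, n_rhs = 6, 4
A = 2.0 * jnp.eye(n) - 0.5 * jnp.eye(n, k=1) - 0.5 * jnp.eye(n, k=-1)
B = jax.random.uniform(jax.random.PRNGKey(0), (n_rhs, n))  # (n_rhs, n)

# Stand-in for `lambda b: ml.solve(b, ...)`: any function of a single right-hand side.
solve_one = lambda b: jnp.linalg.solve(A, b)

solve_batch = jax.jit(jax.vmap(solve_one))
X = solve_batch(B)  # shape (n_rhs, n)

# Row k of X @ A.T equals A @ X[k], so this is the largest residual entry over the batch.
max_residual = float(jnp.max(jnp.abs(X @ A.T - B)))
```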

Differentiating through the solve with jax.grad

f = lambda b: jnp.sum(ml.solve(b, tol=1e-10, maxiter=100))
grad = jax.grad(f)(b)
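
As a sanity check on what this gradient should be: for an exact solve, $f(b) = \mathbf{1}^T A^{-1} b$, so $\nabla_b f = A^{-T}\mathbf{1}$, which equals $A^{-1}\mathbf{1}$ when $A$ is symmetric. The sketch below verifies this with a dense solve standing in for `ml.solve`. How AMJax differentiates through its iterations is not described here, and with an iterative solver the agreement is only as good as the convergence tolerance.

```python
import jax
import jax.numpy as jnp

n = 6
A = 2.0 * jnp.eye(n) - 0.5 * jnp.eye(n, k=1) - 0.5 * jnp.eye(n, k=-1)
b = jnp.ones(n)

# Stand-in for `ml.solve`: a dense solve that JAX can differentiate.
f = lambda b: jnp.sum(jnp.linalg.solve(A, b))

g_autodiff = jax.grad(f)(b)
g_expected = jnp.linalg.solve(A.T, jnp.ones(n))  # A^{-T} 1

max_diff = float(jnp.max(jnp.abs(g_autodiff - g_expected)))
```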

Differentiation with preconditioning

f = lambda b: jnp.sum(jax.scipy.sparse.linalg.cg(A_jax, b, M=M, tol=1e-10)[0])
grad = jax.grad(f)(b)

Features

  • V, W and F cycles compiled with jax.jit
  • Coarse solvers: jacobi, lu, qr, pinv
  • Smoothers: jacobi
  • AMG preconditioning for JAX Krylov solvers (e.g. jax.scipy.sparse.linalg.cg)
  • jax.vmap support for batched right-hand sides
  • jax.grad support through both direct solve and preconditioned Krylov solvers

Solvers

AMJAXSolver.from_pyamg accepts any hierarchy produced by a PyAMG factory:

| Factory | Intended for |
|---|---|
| pyamg.smoothed_aggregation_solver | SPD systems, standard aggregation AMG |
| pyamg.rootnode_solver | SPD systems, robust for anisotropic problems |
| pyamg.pairwise_solver | SPD systems, fast setup, weaker convergence |
| pyamg.ruge_stuben_solver | General SPD systems, classical C/F splitting |
| pyamg.air_solver | Non-symmetric systems |

Current limitations: V-cycle only. jacobi coarse solver only.

Benchmark

An exhaustive benchmark can be run in Colab.

Key speedups relative to the PyAMG-based counterpart:

| Scenario | Method | CPU | GPU |
|---|---|---|---|
| Single solve ($Ax=b$, $b \in \mathbb{R}^n$) | AMJax | - | ~16× |
| Single solve ($Ax=b$, $b \in \mathbb{R}^n$) | AMJax + CG | - | ~17× |
| Batched solve ($AX=B$, $B \in \mathbb{R}^{n \times K}$, $K=64$, jax.vmap) | AMJax | 0.7× | ~21× |
| Batched solve ($AX=B$, $B \in \mathbb{R}^{n \times K}$, $K=64$, jax.vmap) | AMJax + CG | - | ~23× |

Settings: Ruge-Stüben hierarchy, V-cycle, Jacobi smoother, pinv coarse solver, $n = 1{,}000$, f64, rtol $= 10^{-10}$, max 100 iterations. JAX times exclude JIT compilation. GPU speedup is relative to the PyAMG CPU counterpart.
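
Timing JAX code while excluding JIT compilation usually means one warm-up call followed by timed calls that wait for the result. The sketch below shows that pattern with a dense solve standing in for the AMG solve. It is not the benchmark script itself, and the problem size and `n_repeats` are arbitrary choices for this example.

```python
import time

import jax
import jax.numpy as jnp

n = 200
A = 2.0 * jnp.eye(n) - 0.5 * jnp.eye(n, k=1) - 0.5 * jnp.eye(n, k=-1)
b = jnp.ones(n)

solve = jax.jit(lambda b: jnp.linalg.solve(A, b))  # stand-in for the AMG solve

# Warm-up call: triggers tracing and compilation, so it is left out of the timing.
solve(b).block_until_ready()

n_repeats = 10
start = time.perf_counter()
for _ in range(n_repeats):
    # JAX dispatches work asynchronously; wait for the result before stopping the clock.
    solve(b).block_until_ready()
elapsed = (time.perf_counter() - start) / n_repeats  # mean seconds per solve
```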
