JAX Algebraic Multigrid Solvers in Python
AMJax
AMJax bridges PyAMG and JAX for algebraic multigrid (AMG) solvers: it converts PyAMG-constructed hierarchies into jax.{jit,grad,vmap}-compatible, multi-level solvers and preconditioners for large sparse linear systems.
Installation
Install directly from GitHub (PyPI release coming soon):
uv add git+https://github.com/vboussange/AMJax.git
Usage
Direct solve
import pyamg
import jax
import jax.numpy as jnp
from amjax import AMJAXSolver
A = pyamg.gallery.poisson((100, 100), format="csr")
b = jnp.ones(A.shape[0])
ml = AMJAXSolver.from_pyamg(pyamg.ruge_stuben_solver(A))
solve = jax.jit(lambda b: ml.solve(b, tol=1e-10, maxiter=100))
x = solve(b)
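Because `solve` is wrapped in `jax.jit`, a tolerance-based stopping rule cannot be an ordinary Python `while` loop on traced values. One common way to express it is `jax.lax.while_loop`. The sketch below illustrates that pattern under stated assumptions and is not AMJax's implementation: `iterative_solve` and `jacobi_step` are hypothetical names, and a plain Jacobi step on a diagonally dominant test matrix stands in for the multigrid cycle.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)


def iterative_solve(A, b, precond, tol=1e-10, maxiter=100):
    """Iterate x <- x + precond(b - A @ x) until ||b - A x|| <= tol * ||b|| or maxiter.

    Hypothetical sketch: `precond` stands in for one multigrid cycle.
    """
    b_norm = jnp.linalg.norm(b)

    def cond(state):
        _, k, res = state
        # Continue while the relative residual is too large and the iteration budget remains
        return (res > tol * b_norm) & (k < maxiter)

    def body(state):
        x, k, _ = state
        x = x + precond(b - A @ x)  # residual-correction step
        return x, k + 1, jnp.linalg.norm(b - A @ x)

    x0 = jnp.zeros_like(b)
    init = (x0, jnp.array(0), jnp.linalg.norm(b - A @ x0))
    return lax.while_loop(cond, body, init)


# Diagonally dominant test matrix (diagonal 4, off-diagonals -1), so plain Jacobi converges
n = 64
A = 4.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
inv_diag = 1.0 / jnp.diag(A)
jacobi_step = lambda r: inv_diag * r  # hypothetical stand-in for an AMG V-cycle

solve = jax.jit(lambda b: iterative_solve(A, b, jacobi_step, tol=1e-10, maxiter=100))
b = jnp.ones(n)
x, iters, res = solve(b)
```

Under `jax.vmap`, a loop written this way runs until every batch member has converged; JAX's batching rule freezes the entries whose predicate is already false.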
Preconditioning
AMJAXSolver exposes a preconditioner compatible with any JAX Krylov solver:
from jax.experimental import sparse as jsparse
A_jax = jsparse.BCOO.from_scipy_sparse(A)
M = ml.aspreconditioner(cycle='V')
x, info = jax.scipy.sparse.linalg.cg(A_jax, b, M=M, tol=1e-10, maxiter=30)
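The `M` argument of `jax.scipy.sparse.linalg.cg` is a function that approximately applies $A^{-1}$ to a residual; this is the slot an AMG cycle fills. As a self-contained illustration that does not depend on AMJax, the sketch below passes a plain Jacobi (diagonal) preconditioner as a stand-in for `ml.aspreconditioner(...)` and supplies $A$ as a matvec function. The test matrix and the names `matvec` and `jacobi_precond` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)

n = 100
# SPD test matrix: 1D Poisson stencil plus a strongly varying diagonal shift
shift = jnp.linspace(1.0, 100.0, n)
A = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1) + jnp.diag(shift)
b = jnp.ones(n)
inv_diag = 1.0 / jnp.diag(A)


def matvec(v):
    # cg also accepts the operator as a function v -> A @ v
    return A @ v


def jacobi_precond(r):
    # Stand-in for an AMG cycle: M only has to map a residual r to an approximation of A^{-1} r
    return inv_diag * r


solve = jax.jit(lambda b: cg(matvec, b, M=jacobi_precond, tol=1e-10, maxiter=200))
x, info = solve(b)  # info is a placeholder (None) in jax.scipy.sparse.linalg.cg
```

Since `M` is applied once per CG iteration, a stronger but costlier preconditioner such as an AMG V-cycle trades more work per iteration for fewer iterations.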
Batched solve with jax.vmap
import numpy as np
B = jnp.array(np.random.rand(4, A.shape[0])) # (n_rhs, n)
solve_batch = jax.jit(jax.vmap(lambda b: ml.solve(b, tol=1e-8, maxiter=100)))
X = solve_batch(B)
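The Benchmark section writes the batched system as $AX=B$ with right-hand sides stored as the columns of $B \in \mathbb{R}^{n \times K}$, whereas the example above stacks them as rows. The `in_axes`/`out_axes` arguments of `jax.vmap` cover either layout. A minimal sketch, using `jax.scipy.sparse.linalg.cg` on a small dense Poisson matrix as a hypothetical stand-in for `ml.solve` (sizes and names are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)

n, K = 50, 8
A = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)  # SPD 1D Poisson matrix
B = jax.random.normal(jax.random.PRNGKey(0), (n, K))  # each column is one right-hand side


def solve_one(b):
    # Stand-in for ml.solve: unpreconditioned CG on a single right-hand side
    x, _ = cg(lambda v: A @ v, b, tol=1e-10, maxiter=500)
    return x


# in_axes=1 maps over the columns of B; out_axes=1 stores each solution as a column of X
solve_batch = jax.jit(jax.vmap(solve_one, in_axes=1, out_axes=1))
X = solve_batch(B)
```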
Differentiating through the solve with jax.grad
f = lambda b: jnp.sum(ml.solve(b, tol=1e-10, maxiter=100))
grad = jax.grad(f)(b)
Differentiation with preconditioning
f = lambda b: jnp.sum(jax.scipy.sparse.linalg.cg(A_jax, b, M=M, tol=1e-10)[0])
grad = jax.grad(f)(b)
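In both examples, $x = A^{-1}b$ is linear in $b$, so $f(b) = \mathbf{1}^\top A^{-1} b$ and $\nabla_b f = A^{-\top}\mathbf{1}$: the gradient costs one solve with $A^\top$ (with $A$ itself for the symmetric Poisson matrix). The sketch below checks this identity on a small non-symmetric matrix, using `jnp.linalg.solve` as a stand-in for `ml.solve`. It illustrates what `jax.grad` computes, not how AMJax implements its backward pass.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

A = jnp.array([[4.0, 1.0, 0.0],
               [2.0, 5.0, 1.0],
               [0.0, 3.0, 6.0]])  # small non-symmetric test matrix
b = jnp.ones(3)

f = lambda b: jnp.sum(jnp.linalg.solve(A, b))  # f(b) = 1^T A^{-1} b
g_autodiff = jax.grad(f)(b)
g_adjoint = jnp.linalg.solve(A.T, jnp.ones(3))  # closed form: A^{-T} 1, one adjoint solve
```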
Features
- V, W and F cycles compiled with jax.jit
- Coarse solvers: jacobi, lu, qr, pinv
- Smoothers: jacobi
- AMG preconditioning for JAX Krylov solvers (e.g. jax.scipy.sparse.linalg.cg)
- jax.vmap support for batched right-hand sides
- jax.grad support through both direct solves and preconditioned Krylov solvers
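The cycle types listed above differ in how many times each level recursively visits the next coarser one before interpolating the correction back: once for a V-cycle, twice for a W-cycle (the F-cycle is a hybrid and is omitted here). The sketch below is a hypothetical, dense, geometric stand-in for an AMJax hierarchy, not its implementation: it coarsens a 1D Poisson matrix with linear interpolation $P$, forms Galerkin coarse operators $P^\top A P$, smooths with damped Jacobi and solves directly on the coarsest level. All function names are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def poisson_1d(n):
    # Dense tridiagonal 1D Poisson matrix (dense only to keep the sketch short)
    return 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)


def interpolation(n_coarse):
    # Linear interpolation from n_coarse points to 2 * n_coarse + 1 fine points
    j = jnp.arange(n_coarse)
    P = jnp.zeros((2 * n_coarse + 1, n_coarse))
    P = P.at[2 * j + 1, j].set(1.0)
    P = P.at[2 * j, j].set(0.5)
    return P.at[2 * j + 2, j].set(0.5)


def build_hierarchy(A, n_levels):
    levels = []
    for _ in range(n_levels - 1):
        P = interpolation((A.shape[0] - 1) // 2)
        levels.append((A, P))
        A = P.T @ A @ P  # Galerkin coarse-level operator
    levels.append((A, None))  # coarsest level: no interpolation
    return levels


def jacobi(A, x, b, omega=2.0 / 3.0, sweeps=2):
    d = jnp.diag(A)
    for _ in range(sweeps):
        x = x + omega * (b - A @ x) / d  # damped Jacobi sweep
    return x


def cycle(levels, b, gamma=1, lvl=0):
    """Approximate A^{-1} b; gamma=1 gives a V-cycle, gamma=2 a W-cycle."""
    A, P = levels[lvl]
    if P is None:
        return jnp.linalg.solve(A, b)  # direct coarsest-level solve
    x = jacobi(A, jnp.zeros_like(b), b)  # pre-smoothing
    r_c = P.T @ (b - A @ x)  # restrict the residual
    A_c = levels[lvl + 1][0]
    e_c = jnp.zeros_like(r_c)
    for _ in range(gamma):  # visit the coarser level gamma times
        e_c = e_c + cycle(levels, r_c - A_c @ e_c, gamma, lvl + 1)
    x = x + P @ e_c  # coarse-grid correction
    return jacobi(A, x, b)  # post-smoothing


def mg_solve(levels, b, gamma=1, iters=15):
    A = levels[0][0]
    x = jnp.zeros_like(b)
    for _ in range(iters):
        x = x + cycle(levels, b - A @ x, gamma)  # stationary multigrid iteration
    return x


A = poisson_1d(63)
levels = build_hierarchy(A, n_levels=4)  # 63 -> 31 -> 15 -> 7 unknowns
b = jnp.ones(63)
x_v = jax.jit(lambda b: mg_solve(levels, b, gamma=1))(b)
x_w = jax.jit(lambda b: mg_solve(levels, b, gamma=2))(b)
```

A W-cycle does more coarse-level work per cycle, which usually buys a smaller convergence factor per cycle at a higher cost; on an easy problem like this one both variants converge quickly.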
Solvers
AMJAXSolver.from_pyamg accepts any hierarchy produced by a PyAMG factory:
| Factory | Intended for |
|---|---|
| pyamg.smoothed_aggregation_solver | SPD systems, standard aggregation AMG |
| pyamg.rootnode_solver | SPD systems, robust for anisotropic problems |
| pyamg.pairwise_solver | SPD systems, fast setup, weaker convergence |
| pyamg.ruge_stuben_solver | General SPD systems, classical C/F splitting |
| pyamg.air_solver | Non-symmetric systems |
Current limitations: V-cycle only. jacobi coarse solver only.
Benchmark
An exhaustive benchmark can be run in Colab:
Key speedups of AMJax over its PyAMG-based counterpart:
| Scenario | Method | CPU | GPU |
|---|---|---|---|
| Single solve ($Ax=b$, $b \in \mathbb{R}^n$) | AMJax | - | ~16× |
| Single solve ($Ax=b$, $b \in \mathbb{R}^n$) | AMJax + CG | - | ~17× |
| Batched solve ($AX=B$, $B \in \mathbb{R}^{n \times K}$, $K=64$, jax.vmap) | AMJax | 0.7× | ~21× |
| Batched solve ($AX=B$, $B \in \mathbb{R}^{n \times K}$, $K=64$, jax.vmap) | AMJax + CG | - | ~23× |
Settings: Ruge-Stüben hierarchy, V-cycle, Jacobi smoother, pinv coarse solver, $n = 1{,}000$, f64, rtol $= 10^{-10}$, max 100 iterations. JAX times exclude JIT compilation. GPU speedup is relative to the PyAMG CPU counterpart.
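Because JAX compiles a jitted function on its first call and dispatches work asynchronously, excluding JIT compilation from timings typically means one warm-up call followed by timed calls that wait on `jax.block_until_ready`. A minimal sketch of such a helper; the name `time_jitted` and the median-of-repeats choice are assumptions, not necessarily the protocol used for the numbers above.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, repeats=10):
    """Median wall-clock seconds of fn(*args), excluding the first (compiling) call."""
    jax.block_until_ready(fn(*args))  # warm-up: triggers tracing and XLA compilation
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        jax.block_until_ready(fn(*args))  # dispatch is asynchronous: wait for the result
        times.append(time.perf_counter() - t0)
    return statistics.median(times)


A = jnp.eye(500) * 2.0
x = jnp.ones(500)
matvec = jax.jit(lambda v: A @ v)
t_median = time_jitted(matvec, x)
```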
Download files
File details
Details for the file amjax-0.0.1.tar.gz.
File metadata
- Download URL: amjax-0.0.1.tar.gz
- Upload date:
- Size: 23.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.8.15
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7c8c3b137c07561b5e0dfae94c6ae07537f84ea61b49400bd6d1b40dccebb26a |
| MD5 | 5c2a254c0f5552db058457449d11c6c7 |
| BLAKE2b-256 | 11eec7db35d7cec48fec15e8ed0f1bc4c48e9d9022818ea78d916cd80c0f601f |
File details
Details for the file amjax-0.0.1-py3-none-any.whl.
File metadata
- Download URL: amjax-0.0.1-py3-none-any.whl
- Upload date:
- Size: 14.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.8.15
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 157898b6df94c96228159ef16c75bdbe5a5a919fcd3ede7d58ca49bd5826bcc9 |
| MD5 | 221fbe3123a206f9cf1b6626147cd356 |
| BLAKE2b-256 | 85e4eff5fbf7e9b0b9fa0ebf45a18f6893bae4acb61f2e175680d7129e348266 |