
JAX Algebraic Multigrid Solvers in Python

Project description

AMJax

AMJax bridges PyAMG and JAX for algebraic multigrid (AMG) solvers: it converts PyAMG-constructed hierarchies into jax.{jit,grad,vmap}-compatible, multi-level solvers and preconditioners for large sparse linear systems.
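The bridge relies on JAX being able to hold a SciPy sparse matrix as a `jax.experimental.sparse.BCOO` array. SciPy sparse is the format PyAMG works with, and products with a `BCOO` array can be traced by `jax.jit`, `jax.vmap` and `jax.grad`. Below is a minimal sketch of that idea. A small 1D Poisson matrix stands in for a PyAMG operator. The sketch illustrates the mechanism and is not AMJax's internal code:

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from jax.experimental import sparse as jsparse

jax.config.update("jax_enable_x64", True)

# 1D Poisson matrix (tridiagonal [-1, 2, -1]) in CSR format, as PyAMG would hold it
n = 8
A_csr = sp.diags([-np.ones(n - 1), 2 * np.ones(n), -np.ones(n - 1)], [-1, 0, 1], format="csr")

# Convert to a JAX BCOO array; A @ x is then traceable by jit/vmap/grad
A = jsparse.BCOO.from_scipy_sparse(A_csr)

@jax.jit
def residual(x, b):
    return b - A @ x

x = jnp.zeros(n)
b = jnp.ones(n)
r = residual(x, b)                                                 # equals b since x = 0
R = jax.vmap(residual, in_axes=(0, None))(jnp.ones((3, n)), b)     # one residual per row
g = jax.grad(lambda x: 0.5 * jnp.sum(residual(x, b) ** 2))(x)      # gradient is -A^T r
```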

Installation

Install directly from GitHub (PyPI release coming soon):

uv add git+https://github.com/vboussange/AMJax.git

Usage

Direct solve

import pyamg
import jax
import jax.numpy as jnp
from amjax import AMJAXSolver

jax.config.update("jax_enable_x64", True)  # tol=1e-10 needs float64; JAX defaults to float32

A  = pyamg.gallery.poisson((100, 100), format="csr")
b  = jnp.ones(A.shape[0])

ml = AMJAXSolver.from_pyamg(pyamg.ruge_stuben_solver(A))

solve = jax.jit(lambda b: ml.solve(b, tol=1e-10, maxiter=100))
x = solve(b)
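`solve` is wrapped in `jax.jit`, so a tolerance-based stopping rule cannot use a Python `while` on array values. It has to go through a JAX control-flow primitive such as `jax.lax.while_loop`. Below is a minimal sketch of a stationary iteration $x \leftarrow x + M^{-1}(b - Ax)$ driven that way. It assumes a single Jacobi sweep as a hypothetical stand-in for one AMG V-cycle, and AMJax's own loop may be structured differently. Reverse-mode differentiation through `jax.lax.while_loop` is not supported, so a differentiable solve needs a separate mechanism.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)

n = 30
A = jnp.asarray(np.diag(np.full(n, 4.0)) - np.diag(np.ones(n - 1), 1) - np.diag(np.ones(n - 1), -1))

def apply_cycle(r):
    # Hypothetical stand-in for one AMG V-cycle: one Jacobi sweep on A e = r
    return r / jnp.diag(A)

def stationary_solve(b, tol, maxiter):
    b_norm = jnp.linalg.norm(b)

    def cond(state):
        _, r, k = state
        return (jnp.linalg.norm(r) > tol * b_norm) & (k < maxiter)

    def body(state):
        x, r, k = state
        x = x + apply_cycle(r)          # x <- x + M^{-1} r
        return x, b - A @ x, k + 1      # recompute the true residual

    x, _, k = jax.lax.while_loop(cond, body, (jnp.zeros_like(b), b, 0))
    return x, k

solve = jax.jit(lambda b: stationary_solve(b, tol=1e-10, maxiter=100))
x, iters = solve(jnp.ones(n))
```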

Preconditioning

AMJAXSolver exposes a preconditioner that can be passed as M to JAX's Krylov solvers (jax.scipy.sparse.linalg.cg, bicgstab, gmres):

from jax.experimental import sparse as jsparse

A_jax = jsparse.BCOO.from_scipy_sparse(A)
M = ml.aspreconditioner(cycle='V')

x, info = jax.scipy.sparse.linalg.cg(A_jax, b, M=M, tol=1e-10, maxiter=30)  # info is always None in JAX (placeholder)
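JAX's Krylov solvers take `M` as a function applying an approximate inverse, $v \mapsto M^{-1} v \approx A^{-1} v$. For `cg`, that operator should be symmetric positive definite. Below is a minimal sketch of the contract, with a Jacobi (diagonal) preconditioner standing in for `ml.aspreconditioner` (an assumption, chosen only so the example runs without PyAMG):

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from jax.experimental import sparse as jsparse
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)

n = 50
diag = 3.0 + np.linspace(0.0, 1.0, n)   # diagonally dominant SPD test matrix
A_csr = sp.diags([-np.ones(n - 1), diag, -np.ones(n - 1)], [-1, 0, 1], format="csr")
A = jsparse.BCOO.from_scipy_sparse(A_csr)
b = jnp.ones(n)

inv_diag = jnp.asarray(1.0 / diag)

def M(v):
    # Preconditioner: any callable approximating v -> A^{-1} v.
    # A Jacobi sweep stands in here for an AMG V-cycle (assumption).
    return inv_diag * v

x, info = cg(lambda v: A @ v, b, M=M, tol=1e-10, maxiter=200)
rel_res = jnp.linalg.norm(b - A @ x) / jnp.linalg.norm(b)
```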

Batched solve with jax.vmap

import numpy as np

B = jnp.array(np.random.rand(4, A.shape[0]))  # (n_rhs, n)
solve_batch = jax.jit(jax.vmap(lambda b: ml.solve(b, tol=1e-8, maxiter=100)))
X = solve_batch(B)
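`jax.vmap` maps the single-right-hand-side solve over the leading axis of `B`, so each row is solved independently. Because the Krylov loop is a `lax.while_loop`, the batched loop keeps iterating until the slowest right-hand side meets its stopping criterion. The minimal sketch below uses plain `cg` on a dense SPD test matrix instead of `ml.solve` (assumption), and checks the result against a dense solve:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)

n = 40
A_np = (np.diag(3.0 + np.linspace(0.0, 1.0, n))
        - np.diag(np.ones(n - 1), 1) - np.diag(np.ones(n - 1), -1))
A = jnp.asarray(A_np)

def solve_one(b):
    # Solve for a single right-hand side; vmap adds the batch dimension
    x, _ = cg(lambda v: A @ v, b, tol=1e-12, maxiter=200)
    return x

B = jnp.asarray(np.random.default_rng(0).random((4, n)))  # (n_rhs, n)
X = jax.jit(jax.vmap(solve_one))(B)                       # (n_rhs, n), one row per RHS
X_ref = np.linalg.solve(A_np, np.asarray(B).T).T           # dense reference
```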

Differentiating through the solve with jax.grad

f = lambda b: jnp.sum(ml.solve(b, tol=1e-10, maxiter=100))
grad = jax.grad(f)(b)
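For $x = A^{-1} b$, the gradient of $\sum_i x_i$ with respect to $b$ is $A^{-\top}\mathbf{1}$, so one extra solve with the transpose is enough. Backpropagating through every iteration is not needed. The sketch below shows this implicit-differentiation technique with `jax.lax.custom_linear_solve`. It assumes 200 Jacobi sweeps as a stand-in for the AMG solver, and whether AMJax uses this exact mechanism is not stated here:

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)

n = 20
A_np = (np.diag(np.full(n, 4.0))
        - np.diag(np.ones(n - 1), 1) - np.diag(np.ones(n - 1), -1))
A = jnp.asarray(A_np)
d = jnp.diag(A)

def matvec(x):
    return A @ x

def jacobi_solve(matvec, b):
    # Stand-in iterative solver (assumption): 200 Jacobi sweeps, convergent for this A
    body = lambda _, x: x + (b - matvec(x)) / d
    return jax.lax.fori_loop(0, 200, body, jnp.zeros_like(b))

def solve(b):
    # custom_linear_solve defines derivatives implicitly: the backward pass runs
    # the (transpose) solver once instead of differentiating the 200 sweeps.
    return jax.lax.custom_linear_solve(matvec, b, solve=jacobi_solve, symmetric=True)

b = jnp.ones(n)
g = jax.grad(lambda b: jnp.sum(solve(b)))(b)
g_ref = np.linalg.solve(A_np.T, np.ones(n))   # d/db sum(A^{-1} b) = A^{-T} 1
```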

Differentiation with preconditioning

f = lambda b: jnp.sum(jax.scipy.sparse.linalg.cg(A_jax, b, M=M, tol=1e-10)[0])
grad = jax.grad(f)(b)
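JAX's `cg` can itself be differentiated with respect to `b`. The preconditioner changes how fast the solver converges, not the gradient. Below is a minimal check with a dense SPD test matrix and a Jacobi preconditioner standing in for `ml.aspreconditioner` (assumption). It compares the gradients with and without preconditioning against the closed form $A^{-\top}\mathbf{1}$:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)

n = 40
A_np = (np.diag(3.0 + np.linspace(0.0, 1.0, n))
        - np.diag(np.ones(n - 1), 1) - np.diag(np.ones(n - 1), -1))
A = jnp.asarray(A_np)
inv_diag = 1.0 / jnp.diag(A)

def matvec(v):
    return A @ v

def M(v):
    # Jacobi preconditioner standing in for an AMG cycle (assumption)
    return inv_diag * v

f_prec = lambda b: jnp.sum(cg(matvec, b, M=M, tol=1e-12)[0])
f_plain = lambda b: jnp.sum(cg(matvec, b, tol=1e-12)[0])

b = jnp.ones(n)
g_prec = jax.grad(f_prec)(b)
g_plain = jax.grad(f_plain)(b)
g_ref = np.linalg.solve(A_np.T, np.ones(n))   # d/db sum(A^{-1} b) = A^{-T} 1
```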

Features

  • V, W and F cycles compiled with jax.jit
  • Coarse solvers: jacobi, lu, qr, pinv
  • Smoothers: jacobi
  • AMG preconditioning for JAX Krylov solvers (e.g. jax.scipy.sparse.linalg.cg)
  • jax.vmap support for batched right-hand sides
  • jax.grad support through both direct solve and preconditioned Krylov solvers
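The cycle types differ in how many times each level recurses to the next coarser one. The count γ is 1 for a V-cycle and 2 for a W-cycle. The F-cycle mixes the two and is not shown. The sketch below runs on a hypothetical 1D Poisson hierarchy built with linear interpolation rather than a PyAMG hierarchy. It uses damped Jacobi smoothing and an exact coarsest solve, and `jax.jit` unrolls the recursion at trace time:

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)

def poisson_1d(n):
    return (np.diag(np.full(n, 2.0))
            - np.diag(np.ones(n - 1), 1) - np.diag(np.ones(n - 1), -1))

def linear_interp(n_coarse):
    # Coarse point j sits on fine point 2j+1; its neighbours get weight 1/2
    P = np.zeros((2 * n_coarse + 1, n_coarse))
    for j in range(n_coarse):
        P[2 * j, j], P[2 * j + 1, j], P[2 * j + 2, j] = 0.5, 1.0, 0.5
    return P

def build_hierarchy(n_coarsest=3, n_levels=4):
    # Hypothetical stand-in for a PyAMG hierarchy: (A, P) pairs, finest first,
    # with Galerkin coarse operators A_c = P^T A P and restriction R = P^T
    sizes = [n_coarsest]
    for _ in range(n_levels - 1):
        sizes.append(2 * sizes[-1] + 1)
    sizes.reverse()
    A, levels = poisson_1d(sizes[0]), []
    for nc in sizes[1:]:
        P = linear_interp(nc)
        levels.append((jnp.asarray(A), jnp.asarray(P)))
        A = P.T @ A @ P
    return levels, jnp.asarray(A)

def jacobi(A, x, b, sweeps=2, omega=2.0 / 3.0):
    d = jnp.diag(A)
    for _ in range(sweeps):
        x = x + omega * (b - A @ x) / d
    return x

def cycle(levels, A_coarse, b, x, gamma, lvl=0):
    # gamma = 1 gives a V-cycle, gamma = 2 a W-cycle
    if lvl == len(levels):
        return jnp.linalg.solve(A_coarse, b)      # exact solve on the coarsest level
    A, P = levels[lvl]
    x = jacobi(A, x, b)                           # pre-smoothing
    r_c = P.T @ (b - A @ x)                       # restrict the residual
    e_c = jnp.zeros(P.shape[1])
    for _ in range(gamma):                        # visit the coarser level gamma times
        e_c = cycle(levels, A_coarse, r_c, e_c, gamma, lvl + 1)
    x = x + P @ e_c                               # coarse-grid correction
    return jacobi(A, x, b)                        # post-smoothing

levels, A_coarse = build_hierarchy()
A0 = levels[0][0]
b = jnp.ones(A0.shape[0])

def multigrid_solve(b, gamma, n_cycles=15):
    x = jnp.zeros_like(b)
    for _ in range(n_cycles):
        x = cycle(levels, A_coarse, b, x, gamma)
    return x

x_v = jax.jit(lambda b: multigrid_solve(b, gamma=1))(b)
x_w = jax.jit(lambda b: multigrid_solve(b, gamma=2))(b)
```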

Solvers

AMJAXSolver.from_pyamg accepts any hierarchy produced by a PyAMG factory:

| PyAMG factory | Intended for |
|---|---|
| pyamg.smoothed_aggregation_solver | SPD systems, standard aggregation AMG |
| pyamg.rootnode_solver | SPD systems, robust for anisotropic problems |
| pyamg.pairwise_solver | SPD systems, fast setup, weaker convergence |
| pyamg.ruge_stuben_solver | General SPD systems, classical C/F splitting |
| pyamg.air_solver | Non-symmetric systems |
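Each level of a PyAMG `MultilevelSolver` stores its operator `A`, prolongation `P` and restriction `R` as SciPy sparse matrices in `ml.levels`. A conversion therefore amounts to turning each of them into a JAX sparse array. The sketch below assumes those attribute names. The helper `hierarchy_to_bcoo` is hypothetical and is not AMJax's converter. A hand-built two-level hierarchy with pairwise aggregation mimics PyAMG output so the example runs without PyAMG:

```python
from types import SimpleNamespace

import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from jax.experimental import sparse as jsparse

jax.config.update("jax_enable_x64", True)

def hierarchy_to_bcoo(levels):
    """Hypothetical helper: convert PyAMG-style levels (.A, .P, .R) to BCOO arrays.

    The coarsest level only needs A, since it is handled by the coarse solver.
    """
    out = []
    for i, lvl in enumerate(levels):
        A = jsparse.BCOO.from_scipy_sparse(lvl.A)
        if i < len(levels) - 1:
            P = jsparse.BCOO.from_scipy_sparse(lvl.P)
            R = jsparse.BCOO.from_scipy_sparse(lvl.R)
            out.append((A, P, R))
        else:
            out.append((A, None, None))
    return out

# Mimic a two-level hierarchy: 1D Poisson, aggregates {0,1}, {2,3}, ...,
# restriction R = P^T and Galerkin coarse operator A1 = R A0 P
n = 8
A0 = sp.diags([-np.ones(n - 1), 2 * np.ones(n), -np.ones(n - 1)], [-1, 0, 1], format="csr")
P0 = sp.csr_matrix(np.kron(np.eye(n // 2), np.ones((2, 1))))
R0 = P0.T.tocsr()
A1 = (R0 @ A0 @ P0).tocsr()
levels = [SimpleNamespace(A=A0, P=P0, R=R0), SimpleNamespace(A=A1)]

jax_levels = hierarchy_to_bcoo(levels)
```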

Current limitations: V-cycle only. jacobi coarse solver only.

Benchmark

An exhaustive benchmark notebook can be run in Colab.

Key speedups relative to the PyAMG-based counterpart (a dash marks configurations that were not reported):

| Scenario | Method | CPU | GPU |
|---|---|---|---|
| Single solve ($Ax=b$, $b \in \mathbb{R}^n$) | AMJax | - | ~16× |
| Single solve ($Ax=b$, $b \in \mathbb{R}^n$) | AMJax + CG | - | ~17× |
| Batched solve ($AX=B$, $B \in \mathbb{R}^{n \times K}$, $K=64$, jax.vmap) | AMJax | 0.7× | ~21× |
| Batched solve ($AX=B$, $B \in \mathbb{R}^{n \times K}$, $K=64$, jax.vmap) | AMJax + CG | - | ~23× |

Settings: Ruge-Stüben hierarchy, V-cycle, Jacobi smoother, pinv coarse solver, $n = 1{,}000$, f64, rtol $= 10^{-10}$, max 100 iterations. JAX times exclude JIT compilation. GPU speedup is relative to the PyAMG CPU counterpart.
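Excluding JIT compilation means timing only after a warm-up call has triggered tracing and compilation. JAX also dispatches work asynchronously, so each timed call has to wait for its result. Below is a minimal sketch of such a timing helper. It is hypothetical, not the benchmark's actual script, and assumes the timed function returns a single array:

```python
import time

import jax
import jax.numpy as jnp

def time_jitted(fn, *args, repeats=10):
    """Median wall time of fn(*args), excluding compilation (hypothetical helper)."""
    fn(*args).block_until_ready()          # warm-up: triggers tracing and compilation
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args).block_until_ready()      # wait for the asynchronous result
        times.append(time.perf_counter() - t0)
    times.sort()
    return times[len(times) // 2]

f = jax.jit(lambda x: jnp.sin(x) @ jnp.cos(x).T)
x = jnp.ones((256, 256))
t = time_jitted(f, x)
```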


Download files

Source Distribution

amjax-0.0.1.tar.gz (23.9 kB)

Uploaded Source

Built Distribution

amjax-0.0.1-py3-none-any.whl (14.5 kB)

Uploaded Python 3

File details

Details for the file amjax-0.0.1.tar.gz.

File metadata

  • Download URL: amjax-0.0.1.tar.gz
  • Upload date:
  • Size: 23.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.8.15

File hashes

Hashes for amjax-0.0.1.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7c8c3b137c07561b5e0dfae94c6ae07537f84ea61b49400bd6d1b40dccebb26a |
| MD5 | 5c2a254c0f5552db058457449d11c6c7 |
| BLAKE2b-256 | 11eec7db35d7cec48fec15e8ed0f1bc4c48e9d9022818ea78d916cd80c0f601f |
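The SHA256 digest can be checked locally before installing. Below is a minimal sketch with Python's standard `hashlib`. The helper name `sha256_of` is hypothetical, and the file name is the source distribution listed above:

```python
import hashlib
from pathlib import Path

def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, read in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

EXPECTED = "7c8c3b137c07561b5e0dfae94c6ae07537f84ea61b49400bd6d1b40dccebb26a"
sdist = Path("amjax-0.0.1.tar.gz")
if sdist.exists():  # only check once the sdist has been downloaded
    assert sha256_of(sdist) == EXPECTED, "SHA256 mismatch"
```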

File details

Details for the file amjax-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: amjax-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 14.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.8.15

File hashes

Hashes for amjax-0.0.1-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 157898b6df94c96228159ef16c75bdbe5a5a919fcd3ede7d58ca49bd5826bcc9 |
| MD5 | 221fbe3123a206f9cf1b6626147cd356 |
| BLAKE2b-256 | 85e4eff5fbf7e9b0b9fa0ebf45a18f6893bae4acb61f2e175680d7129e348266 |
