
JAX bindings for the NVIDIA AmgX sparse linear solver library, with automatic differentiation support


JAX-AMG

License: Apache 2.0

JAX-AMG brings the power of NVIDIA's AmgX library to the JAX ecosystem, providing high-performance, GPU-accelerated sparse linear solvers with full support for automatic differentiation.

Documentation: https://jx-wang-s-group.github.io/JAX-AMG/

Features

  • GPU-Accelerated Solvers: Leverages NVIDIA AmgX for a broad range of GPU-accelerated sparse linear solvers, including algebraic multigrid (AMG), Krylov methods, and their variants, with flexible configuration options for solvers, smoothers, and preconditioners.
  • Automatic Differentiation: Supports adjoint-based gradient computation and integrates seamlessly with JAX for end-to-end differentiable workflows.
  • JIT Compilation: Built as a native JAX primitive, fully compatible with Just-in-Time compilation (jax.jit) for efficient, low-overhead execution.
  • MPI Support: Enables distributed linear solves across multiple GPUs, with GPU-aware MPI support.
  • Matrix-Free Operators: Beyond explicit matrices, A can be a callable operator. The library recovers the exact sparsity pattern in a single pass by tracing the operator's computation graph, then assembles the matrix the solver needs.
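To make the adjoint-gradient bullet concrete, here is a minimal dense NumPy sketch of the standard adjoint method for a scalar loss L(x) with A x = b. It is not JAX-AMG's implementation (which runs on the GPU through AmgX); the helper name `adjoint_grads` and the quadratic loss are hypothetical choices for illustration. One extra solve with A^T gives the gradients with respect to both b and A, and a finite-difference check confirms them.

```python
import numpy as np


def adjoint_grads(A, b, dL_dx):
    """Gradients of a scalar loss L(x), where A x = b, via the adjoint method.

    Hypothetical helper (not part of jaxamg). One adjoint solve
    A^T lam = dL/dx at the solution gives:
        dL/db = lam
        dL/dA = -outer(lam, x)
    """
    x = np.linalg.solve(A, b)
    lam = np.linalg.solve(A.T, dL_dx(x))  # the single extra (adjoint) solve
    return x, lam, -np.outer(lam, x)


def loss(A, b):
    """Example loss L(x) = 0.5 * ||x||^2, so dL/dx = x."""
    x = np.linalg.solve(A, b)
    return 0.5 * x @ x


rng = np.random.default_rng(0)
n = 5
A = 4.0 * np.eye(n) + 0.1 * rng.standard_normal((n, n))  # well-conditioned
b = rng.standard_normal(n)
x, grad_b, grad_A = adjoint_grads(A, b, lambda x: x)

# Central finite-difference check of one entry of each gradient.
eps = 1e-6
e0 = np.eye(n)[0]
fd_b = (loss(A, b + eps * e0) - loss(A, b - eps * e0)) / (2 * eps)
E = np.zeros((n, n))
E[1, 2] = 1.0
fd_A = (loss(A + eps * E, b) - loss(A - eps * E, b)) / (2 * eps)
print(abs(fd_b - grad_b[0]), abs(fd_A - grad_A[1, 2]))
```

The adjoint route costs one extra solve regardless of how many entries of A or b are being differentiated, which is why it suits large sparse systems.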

Prerequisites

  • Python 3.10+
  • JAX 0.5.0+ with CUDA support
  • AmgX 2.5.0+
  • CUDA Toolkit 12.0+

Additional requirements for distributed (MPI) mode:

  • MPI library (e.g., OpenMPI, MPICH)
  • CUDA-aware MPI (optional, for GPU-direct communication)

Installation

JAX-AMG is installed with pip. It compiles a native extension against a CUDA toolkit and a source build of NVIDIA AmgX, so set CUDA_HOME and AMGX_ROOT first, then run the command for your CUDA version:

pip install "jaxamg[cuda12]"   # or jaxamg[cuda13]

At runtime, add the AmgX and CUDA libraries to your library path:

export LD_LIBRARY_PATH=$AMGX_ROOT/build:$CUDA_HOME/lib64:$LD_LIBRARY_PATH
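The environment steps above can also be scripted, for example in a launcher that checks the setup before running a job. This is a minimal sketch assuming only that AMGX_ROOT and CUDA_HOME point at the AmgX source build and the CUDA toolkit as described; the helper `amgx_library_path` is hypothetical and not part of jaxamg.

```python
import os


def amgx_library_path(env=None):
    """Build the LD_LIBRARY_PATH value recommended above.

    Hypothetical helper: prepends $AMGX_ROOT/build and $CUDA_HOME/lib64 to any
    existing LD_LIBRARY_PATH, mirroring the shell `export` line.
    """
    env = os.environ if env is None else env
    missing = [k for k in ("AMGX_ROOT", "CUDA_HOME") if not env.get(k)]
    if missing:
        raise RuntimeError(f"set {', '.join(missing)} before installing jaxamg")
    parts = [
        f"{env['AMGX_ROOT'].rstrip('/')}/build",
        f"{env['CUDA_HOME'].rstrip('/')}/lib64",
    ]
    # Unlike the shell line, skip the trailing ':' when LD_LIBRARY_PATH is unset.
    if env.get("LD_LIBRARY_PATH"):
        parts.append(env["LD_LIBRARY_PATH"])
    return ":".join(parts)


print(amgx_library_path({"AMGX_ROOT": "/opt/AMGX", "CUDA_HOME": "/usr/local/cuda"}))
```

The dynamic loader reads LD_LIBRARY_PATH when a process starts, so the value should be exported by the shell or a launcher before Python runs, not assigned to os.environ inside an already-running interpreter.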

For distributed (MPI) mode, and for installing with the install script, with conda, or from source, see the full Installation Guide.


Quick Start

A simple tridiagonal system can be solved as:

import jaxamg
from jaxamg.matrices import tridiagonal_matrix, rhs_ones

# Create a simple tridiagonal system
n = 100
A = tridiagonal_matrix(n, diagonal_value=2.0)
b = rhs_ones(n)

# Solve Ax = b
x, info = jaxamg.solve(A, b)
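To sanity-check the result, it can be compared against a CPU reference. The sketch below assumes that `tridiagonal_matrix(n, diagonal_value=2.0)` builds the classic matrix with 2 on the diagonal and -1 on the off-diagonals (an assumption; check the library's documentation). In that case A x = 1 has the closed-form solution x_i = i(n + 1 - i)/2 for i = 1, ..., n. The helpers `dense_tridiagonal`, `exact_solution`, and `check_solution` are hypothetical.

```python
import numpy as np


def dense_tridiagonal(n, diagonal_value=2.0, off_value=-1.0):
    """Dense stand-in for the assumed tridiagonal test matrix."""
    A = np.diag(np.full(n, diagonal_value))
    A += np.diag(np.full(n - 1, off_value), k=1)
    A += np.diag(np.full(n - 1, off_value), k=-1)
    return A


def exact_solution(n):
    """Closed form of tridiag(-1, 2, -1) x = ones: x_i = i (n + 1 - i) / 2."""
    i = np.arange(1, n + 1)
    return i * (n + 1 - i) / 2.0


def check_solution(x, n, rtol=1e-6):
    """Return the relative residual ||A x - b|| / ||b|| and closed-form agreement."""
    A, b = dense_tridiagonal(n), np.ones(n)
    rel_res = np.linalg.norm(A @ x - b) / np.linalg.norm(b)
    return rel_res, np.allclose(x, exact_solution(n), rtol=rtol)


n = 100
x_ref = np.linalg.solve(dense_tridiagonal(n), np.ones(n))
print(check_solution(x_ref, n))
```

With the Quick Start result, `check_solution(np.asarray(x), n)` would apply the same check; since AmgX solvers iterate to a tolerance, `rtol` may need loosening to match the configured convergence criterion.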

MPI Distributed Solving

A distributed 2D Poisson system can be solved with GPU-aware MPI as:

from mpi4py import MPI
import jaxamg
from jaxamg.mpi_utils import partition_vector, gather_solution
from jaxamg.matrices import poisson_matrix_distributed, rhs_ones

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
nranks = comm.Get_size()

# Create distributed 2D Poisson matrix
n = 16
A_local, row_start, row_end = poisson_matrix_distributed(n, n, rank, nranks)
b_local, _, _ = partition_vector(rhs_ones(n * n), rank, nranks)

# Solve in distributed mode
x_local, info = jaxamg.solve(
    A_local, b_local,
    comm=comm,
    nglobal=n * n,
    partition_info=(row_start, row_end),
    config={
        "solver": "CG",
        "preconditioner": {"solver": "JACOBI_L1"},
        "communicator": "MPI_DIRECT",
    }
)

# Gather solution at root rank
x_global = gather_solution(x_local, comm, root=0)
if rank == 0:
    print(x_global)
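In the example above each rank owns a contiguous block of rows [row_start, row_end). The sketch below shows one common way to split nglobal rows across ranks and which global columns a rank's rows of the 5-point 2D Poisson stencil reference. It assumes remainder rows go to the lowest ranks and unknowns are numbered row-major on the n x n grid, which may differ from what `poisson_matrix_distributed` and `partition_vector` actually do; all helper names are hypothetical. Columns outside the owned block are the halo values that neighbouring ranks must supply during each matrix-vector product.

```python
def block_row_partition(nglobal, rank, nranks):
    """Contiguous rows [row_start, row_end) owned by `rank` (hypothetical helper).

    The first nglobal % nranks ranks receive one extra row each.
    """
    base, extra = divmod(nglobal, nranks)
    row_start = rank * base + min(rank, extra)
    row_end = row_start + base + (1 if rank < extra else 0)
    return row_start, row_end


def poisson_row_columns(row, nx, ny):
    """Global column indices in one row of the 5-point Laplacian (row-major grid)."""
    i, j = divmod(row, nx)  # grid row i, grid column j
    cols = [row]
    if i > 0:
        cols.append(row - nx)
    if i < ny - 1:
        cols.append(row + nx)
    if j > 0:
        cols.append(row - 1)
    if j < nx - 1:
        cols.append(row + 1)
    return sorted(cols)


def halo_columns(nx, ny, rank, nranks):
    """Columns referenced by this rank's rows but owned by other ranks."""
    row_start, row_end = block_row_partition(nx * ny, rank, nranks)
    needed = {
        c for r in range(row_start, row_end) for c in poisson_row_columns(r, nx, ny)
    }
    return sorted(c for c in needed if not row_start <= c < row_end)


n, nranks = 16, 3
for rank in range(nranks):
    rows = block_row_partition(n * n, rank, nranks)
    print(rank, rows, len(halo_columns(n, n, rank, nranks)))
```

For a row-major grid split into row blocks, interior ranks exchange one grid line with each neighbour, so the halo size scales with n rather than with the number of owned rows.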

Citation

If you use JAX-AMG in your work, please consider citing it (arXiv:2606.09001):

@misc{jaxamg2026,
      title={JAX-AMG: A GPU-Accelerated Differentiable Sparse Linear Solver Library for JAX},
      author={Yi Liu and Xiantao Fan and Jian-Xun Wang},
      year={2026},
      eprint={2606.09001},
      archivePrefix={arXiv},
      primaryClass={cs.MS},
      url={https://arxiv.org/abs/2606.09001},
}



Download files


Source Distribution

jaxamg-0.1.0.tar.gz (80.7 kB)


File details

Details for the file jaxamg-0.1.0.tar.gz.

File metadata

  • Download URL: jaxamg-0.1.0.tar.gz
  • Upload date:
  • Size: 80.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jaxamg-0.1.0.tar.gz
  • SHA256: 12515233529d4f0083fbdedd039e9b7c162fe35361f677cbe18f3c599056d049
  • MD5: 27f474ee6c74b95cf522859302ee2ce3
  • BLAKE2b-256: af934173cf8163324891b1b2d1d0a8df1f9e740b4fddea69243eabfd8f321afd
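A downloaded archive can be checked against the published SHA256 digest with the standard library. This is a minimal sketch; the local filename and the helper names `sha256_of` and `verify_sdist` are assumptions for illustration. The file is streamed in chunks so memory use stays flat, and the loop form works on Python 3.10 (hashlib.file_digest only arrived in 3.11).

```python
import hashlib
from pathlib import Path

EXPECTED_SHA256 = "12515233529d4f0083fbdedd039e9b7c162fe35361f677cbe18f3c599056d049"


def sha256_of(path, chunk_size=1 << 20):
    """Stream the file through SHA-256 and return the hex digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_sdist(path, expected=EXPECTED_SHA256):
    """True if the file's SHA-256 digest matches the published value."""
    return sha256_of(path) == expected.lower()


archive = Path("jaxamg-0.1.0.tar.gz")  # assumed download location
if archive.exists():
    print("OK" if verify_sdist(archive) else "hash mismatch")
```

pip can enforce the same check at install time through its hash-checking mode, by pinning `jaxamg==0.1.0 --hash=sha256:<digest>` in a requirements file.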


Provenance

The following attestation bundles were made for jaxamg-0.1.0.tar.gz:

Publisher: publish.yml on jx-wang-s-group/JAX-AMG

