JAX backend for Apple M-series chips
jax-metallib
A PJRT plugin that enables JAX to run on Apple Metal (MPS) GPUs on Apple Silicon. It compiles StableHLO IR to Metal shaders via MPSGraph, giving JAX programs GPU acceleration on M-series Macs.
Status: Alpha (v0.9.1). The API and op coverage are still evolving.
Requirements
- macOS 13+ on Apple Silicon (M1 / M2 / M3 / M4)
- Python 3.11+
- jax and jaxlib 0.9.x
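Before installing, you can check the requirements above from Python. The sketch below is not part of jax-metallib: check_requirements is a hypothetical helper that uses only the standard library (platform, sys, importlib.metadata). It assumes that any jax/jaxlib version string starting with "0.9." satisfies "0.9.x".

```python
from __future__ import annotations

import platform
import sys
from importlib.metadata import PackageNotFoundError, version


def check_requirements() -> list[str]:
    """Return human-readable problems; an empty list means all checks passed."""
    problems: list[str] = []

    # Python 3.11+
    if sys.version_info < (3, 11):
        problems.append(f"Python 3.11+ required, found {platform.python_version()}")

    # jax and jaxlib 0.9.x (assumption: any version starting with "0.9." is fine)
    for pkg in ("jax", "jaxlib"):
        try:
            found = version(pkg)
        except PackageNotFoundError:
            problems.append(f"{pkg} 0.9.x required, but it is not installed")
        else:
            if not found.startswith("0.9."):
                problems.append(f"{pkg} 0.9.x required, found {found}")

    # macOS 13+ on Apple Silicon
    if platform.system() != "Darwin":
        problems.append(f"macOS required, found {platform.system()}")
        return problems  # the remaining checks only apply on macOS
    if platform.machine() != "arm64":
        problems.append(f"Apple Silicon (arm64) required, found {platform.machine()}")
    release = platform.mac_ver()[0]  # e.g. "14.5"; empty string if unknown
    try:
        major = int(release.split(".")[0])
    except ValueError:
        problems.append(f"could not determine the macOS version ({release!r})")
    else:
        if major < 13:
            problems.append(f"macOS 13+ required, found {release}")

    return problems


if __name__ == "__main__":
    issues = check_requirements()
    print("OK" if not issues else "\n".join(issues))
```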
Install
```bash
pip install jax-metallib
```
Verify:
```bash
JAX_PLATFORMS=mps python -c "import jax; print(jax.devices())"
# [MpsDevice(id=0)]
```
Build from source
```bash
brew install cmake ninja
git clone https://github.com/erfanzar/jax-metallib.git
cd jax-metallib
uv sync --all-groups
uv pip install -e .   # auto-bootstraps native deps on first build (~30 min)
```
To skip the automatic dependency bootstrap (if you manage deps yourself):
```bash
CMAKE_ARGS="-DJAX_SILICON_AUTO_SETUP_DEPS=OFF" uv pip install -e .
```
Native dependencies
scripts/setup_deps.sh fetches and builds:
| Dependency | Version / Commit |
|---|---|
| LLVM + MLIR | XLA pin bb760b0 |
| StableHLO | 127d2f2 |
| Abseil | 20250127.0 |
| Protobuf | 29.3 |
These are installed to ~/.local/jax-silicon-deps by default.
Usage
The plugin registers as the mps platform in JAX:
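```python
import jax
import jax.numpy as jnp

# Select the Metal backend: either export JAX_PLATFORMS=mps before starting
# Python, or set it through jax.config before any computation runs.
jax.config.update("jax_platforms", "mps")

x = jnp.ones((1024, 1024))
y = x @ x  # runs on the Metal GPU
```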
Environment variables
| Variable | Description |
|---|---|
| JAX_PLATFORMS=mps | Select the MPS backend |
| JAX_SILICON_LIBRARY_PATH | Override the path to libpjrt_plugin_silicon.dylib |
| JAX_MPS_LIBRARY_PATH | Legacy alias for JAX_SILICON_LIBRARY_PATH |
| MPS_LOG_LEVEL=0..3 | Logging verbosity (0=error, 1=warn, 2=info, 3=debug) |
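The following sketch shows one way the two library-path variables and the log level could be resolved. It illustrates the table and is not the plugin's code. Two behaviors are assumptions: JAX_SILICON_LIBRARY_PATH takes precedence over the legacy JAX_MPS_LIBRARY_PATH, and an unset or out-of-range MPS_LOG_LEVEL falls back to 1 (warn). resolve_library_path and resolve_log_level are hypothetical names.

```python
from __future__ import annotations

import os
from collections.abc import Mapping

# Level names as documented for MPS_LOG_LEVEL.
LOG_LEVELS = {0: "error", 1: "warn", 2: "info", 3: "debug"}


def resolve_library_path(env: Mapping[str, str] = os.environ) -> str | None:
    """Return the dylib override path, or None to use the bundled library.

    Assumption: the current name wins over the legacy alias when both are set.
    """
    for name in ("JAX_SILICON_LIBRARY_PATH", "JAX_MPS_LIBRARY_PATH"):
        value = env.get(name)
        if value:
            return value
    return None


def resolve_log_level(env: Mapping[str, str] = os.environ, default: int = 1) -> str:
    """Map MPS_LOG_LEVEL to a level name.

    Assumption: unset, non-integer, or out-of-range values fall back to `default`.
    """
    try:
        level = int(env.get("MPS_LOG_LEVEL", ""))
    except ValueError:
        level = default
    return LOG_LEVELS.get(level, LOG_LEVELS[default])


# Example with an explicit mapping instead of the real environment
# (the path is made up for illustration).
example_env = {
    "JAX_MPS_LIBRARY_PATH": "/opt/example/libpjrt_plugin_silicon.dylib",
    "MPS_LOG_LEVEL": "3",
}
print(resolve_library_path(example_env))  # /opt/example/libpjrt_plugin_silicon.dylib
print(resolve_log_level(example_env))     # debug
```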
Supported operations
100+ StableHLO operations are implemented across these categories:
| Category | Examples |
|---|---|
| Unary | tanh, exp, log, sin, cos, sqrt, erf, abs, sign |
| Binary | add, subtract, multiply, divide, dot, comparisons |
| Reductions | reduce_sum, reduce_max, reduce_min, argmax, argmin |
| Shape | reshape, transpose, slice, pad, concatenate, gather, scatter |
| Convolution | conv_general_dilated with arbitrary padding/dilation |
| Linear algebra | matmul, cholesky, qr, svd, triangular_solve |
| FFT | fft, rfft, ifft, irfft |
| Random | Threefry / Philox RNG |
| Sort | sort, top_k |
| Control flow | cond (if/else), while_loop, scan |
| Bitwise | and, or, xor, shift_left, shift_right |
Encountering an unsupported op prints a diagnostic with a link to file a feature request.
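To check ahead of time which StableHLO ops a program uses, you can dump its IR with jax.jit(f).lower(*args).as_text() and scan the op names. The sketch below performs that scan with a regular expression. The SUPPORTED set is a hypothetical allowlist for illustration, not the plugin's registry, so the plugin's own diagnostic remains the authoritative check.

```python
from __future__ import annotations

import re

# Matches op names such as `stablehlo.add` or `"stablehlo.add"` (generic form),
# but not attribute syntax such as `#stablehlo.dot<...>`.
_OP_RE = re.compile(r"(?<![#\w])stablehlo\.(\w+)")


def stablehlo_ops(ir_text: str) -> set[str]:
    """Return the set of StableHLO op names that appear in MLIR text."""
    return set(_OP_RE.findall(ir_text))


def unsupported_ops(ir_text: str, supported: set[str]) -> set[str]:
    """Return ops used by the program that are missing from `supported`."""
    return stablehlo_ops(ir_text) - supported


# Hand-written StableHLO in the shape that jax.jit(f).lower(x).as_text() emits.
IR = """
func.func public @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %0 = stablehlo.exponential %arg0 : tensor<4xf32>
  %1 = stablehlo.add %0, %arg0 : tensor<4xf32>
  %2 = stablehlo.cbrt %1 : tensor<4xf32>
  return %2 : tensor<4xf32>
}
"""

# Hypothetical allowlist for illustration only.
SUPPORTED = {"add", "exponential", "tanh", "dot_general"}

print(sorted(stablehlo_ops(IR)))               # ['add', 'cbrt', 'exponential']
print(sorted(unsupported_ops(IR, SUPPORTED)))  # ['cbrt'] (not in this allowlist)
```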
Testing
```bash
uv run pytest                      # compare CPU vs MPS (default)
JAX_TEST_MODE=mps uv run pytest    # MPS only
JAX_TEST_MODE=cpu uv run pytest    # CPU only
```
The test suite covers value correctness, gradient accuracy, and includes integration tests with Flax and NumPyro.
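The three test modes can be modeled as a choice of which backends a test runs on, where the default runs both backends and compares their values. The sketch below is hypothetical, and the suite's actual conftest may be organized differently. backends_for and assert_close are illustrative names, and the tolerances are placeholders, not the suite's thresholds.

```python
from __future__ import annotations

import os
from collections.abc import Sequence


def backends_for(mode: str | None) -> tuple[str, ...]:
    """Backends a test runs on for a given JAX_TEST_MODE value."""
    if not mode:
        return ("cpu", "mps")  # default: run on both and compare
    if mode in ("cpu", "mps"):
        return (mode,)
    raise ValueError(f"unknown JAX_TEST_MODE {mode!r}; expected 'cpu' or 'mps'")


def assert_close(expected: Sequence[float], actual: Sequence[float],
                 rtol: float = 1e-5, atol: float = 1e-6) -> None:
    """Element-wise check |actual - expected| <= atol + rtol * |expected|,
    the criterion numpy.testing.assert_allclose uses (NaNs always fail here)."""
    if len(expected) != len(actual):
        raise AssertionError(f"length mismatch: {len(expected)} != {len(actual)}")
    for i, (e, a) in enumerate(zip(expected, actual)):
        if not abs(a - e) <= atol + rtol * abs(e):
            raise AssertionError(f"element {i}: expected {e!r}, got {a!r}")


print(backends_for(os.environ.get("JAX_TEST_MODE")))
```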
How it works
```
JAX Python code
      |
StableHLO IR (MLIR)
      |
stablehlo_parser.mm    -- parses IR, looks up ops in the registry
      |
MPSGraph operations    -- builds a Metal compute graph
      |
Metal command buffer   -- compiled & dispatched to GPU
      |
Device memory result
```
pjrt_api.cc implements the PJRT C API, exposing the client, device, buffer, and executable abstractions that JAX expects. Each StableHLO op is registered in src/pjrt_plugin/ops/ and mapped to the corresponding MPSGraph method.
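The registry lookup described above can be modeled in a few lines of Python. This is a conceptual sketch only, since the real registry is C++/Obj-C++ code under src/pjrt_plugin/ops/. Here each builder just returns a string naming the MPSGraph selector it would call. The selectors additionWithPrimaryTensor:secondaryTensor:name: and tanhWithTensor:name: are real MPSGraph methods, but everything else, including register, lower, and UnsupportedOpError, is hypothetical.

```python
from __future__ import annotations

from collections.abc import Callable

# A builder takes operand names and returns a description of the graph node.
Builder = Callable[[list[str]], str]
_REGISTRY: dict[str, Builder] = {}


def register(op_name: str) -> Callable[[Builder], Builder]:
    """Decorator that maps a StableHLO op name to a graph-building function."""
    def wrap(fn: Builder) -> Builder:
        if op_name in _REGISTRY:
            raise ValueError(f"duplicate registration for {op_name}")
        _REGISTRY[op_name] = fn
        return fn
    return wrap


@register("stablehlo.add")
def _add(operands: list[str]) -> str:
    return f"additionWithPrimaryTensor:{operands[0]} secondaryTensor:{operands[1]}"


@register("stablehlo.tanh")
def _tanh(operands: list[str]) -> str:
    return f"tanhWithTensor:{operands[0]}"


class UnsupportedOpError(NotImplementedError):
    """Raised when an op has no registered builder."""


def lower(ops: list[tuple[str, list[str]]]) -> list[str]:
    """Look up each (op_name, operands) pair in the registry, in program order."""
    graph = []
    for name, operands in ops:
        builder = _REGISTRY.get(name)
        if builder is None:
            raise UnsupportedOpError(f"unsupported op {name!r}; please file a feature request")
        graph.append(builder(operands))
    return graph


print(lower([("stablehlo.add", ["x", "y"]), ("stablehlo.tanh", ["t0"])]))
# ['additionWithPrimaryTensor:x secondaryTensor:y', 'tanhWithTensor:t0']
```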
Repository layout
```
src/
  jax_plugins/silicon/     Python entrypoint (plugin registration)
  pjrt_plugin/             C++/Obj-C++ PJRT backend
    ops/                   Op implementations (100+ ops)
    stablehlo_parser.mm    StableHLO IR -> MPSGraph compiler
    mps_client.mm          Metal device & command queue management
    mps_executable.mm      Executable compilation & dispatch
tests/
  test_ops.py              Parametrized op tests
  configs/                 Per-category test configurations
scripts/
  setup_deps.sh            One-time native dependency bootstrap
  release.sh               Release automation
```
Contributing
See CONTRIBUTING.md. The short version:
```bash
brew install cmake ninja && ./scripts/setup_deps.sh && uv sync --all-groups
uv pip install -e . && pre-commit install
```
- Add your op in src/pjrt_plugin/ops/, add a test in tests/configs/, then rebuild and run the tests with uv run pytest.
License
Apache-2.0. See LICENSE.
This project is a derivative of tillahoffmann/jax-mps.
Download files
File details
Details for the file jax_metallib-0.9.4.tar.gz.
File metadata
- Download URL: jax_metallib-0.9.4.tar.gz
- Upload date:
- Size: 257.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.21
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1a851486867c291bfa02baa42b02f2c44286a8e475f557636b4d5b61bc1fc4b5 |
| MD5 | aa8dd868fc502e424e39a7578988e6b4 |
| BLAKE2b-256 | 4d1b0a91b264690cb7b2caa4673099ae19a4e8e152385557fc7a5e962abde673 |
File details
Details for the file jax_metallib-0.9.4-cp313-cp313-macosx_26_0_arm64.whl.
File metadata
- Download URL: jax_metallib-0.9.4-cp313-cp313-macosx_26_0_arm64.whl
- Upload date:
- Size: 6.8 MB
- Tags: CPython 3.13, macOS 26.0+ ARM64
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.21
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 498c81fd07f0718cb0dcd7a3224af3ade973cf00464c5109f00bd17481f22f03 |
| MD5 | 6ce9610cb195ba43e410e1098276e493 |
| BLAKE2b-256 | 75a31e00232b1bc3e5c777af1389ca5e5b8799f5de309b6b5aa0518be6c17fb3 |