JAX backend for Apple M series of chips
jax-metallib
Run JAX on Apple Metal GPUs
jax-metallib is a PJRT plugin that enables JAX to run on Apple Metal GPUs. It compiles StableHLO IR to Metal compute kernels via MPSGraph, giving JAX programs native GPU acceleration on M-series Macs — no code changes required.
Highlights
- 120+ StableHLO ops — from basic arithmetic through convolutions, FFTs, linear algebra (Cholesky, QR, SVD), sorting, and control flow
- Drop-in acceleration — set JAX_PLATFORMS=mps and existing JAX code runs on the Metal GPU
- Full gradient support — jax.grad, jax.value_and_grad, and higher-order derivatives work out of the box
- JIT kernel fusion — consecutive elementwise ops are fused into single Metal Shading Language kernels at runtime
- Native MPS kernels — performance-critical operations (Cholesky decomposition, triangular solve) use MPS native kernels directly, bypassing the graph compiler
- Executable serialization — compiled programs can be serialized and deserialized for AOT compilation and jax.jit caching
- Framework integration — tested with Flax (NNX) and NumPyro
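The AOT compilation mentioned in the list above can be exercised through JAX's standard staging API (`lower`, then `compile`). The following is a minimal sketch rather than plugin-specific code: the function `f` and the input shape are illustrative assumptions, and it runs on whichever backend JAX selects (set JAX_PLATFORMS=mps to target Metal).

```python
import jax
import jax.numpy as jnp


def f(x):
    # Illustrative workload: an elementwise op followed by a matmul.
    return jnp.tanh(x) @ x


x = jnp.ones((4, 4), dtype=jnp.float32)

# Stage 1: trace and lower to StableHLO without executing anything.
lowered = jax.jit(f).lower(x)
stablehlo_text = lowered.as_text()

# Stage 2: compile ahead of time for the active backend.
compiled = lowered.compile()

# Stage 3: call the compiled executable directly.
out = compiled(x)
print(out.shape)
```

Inspecting `stablehlo_text` is also a convenient way to see which StableHLO ops a program asks the backend to handle.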
Quick Start
Install
pip install jax-metallib
Verify
JAX_PLATFORMS=mps python -c "import jax; print(jax.devices())"
Run
import jax
import jax.numpy as jnp
x = jnp.ones((1024, 1024))
y = x @ x
f = jax.jit(jax.grad(lambda x: jnp.sum(jnp.tanh(x))))
grads = f(x)
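To sanity-check the gradient path, the jitted gradient of sum(tanh(x)) can be compared with its closed form, 1 - tanh(x)^2. This is a small sketch under the assumption that float32 is used; the input values and the tolerance are illustrative choices, and the code runs on any JAX backend.

```python
import jax
import jax.numpy as jnp


def loss(x):
    return jnp.sum(jnp.tanh(x))


x = jnp.linspace(-2.0, 2.0, 8)

# Gradient computed by JAX, compiled with jit.
grads = jax.jit(jax.grad(loss))(x)

# Analytic derivative of tanh, evaluated elementwise.
expected = 1.0 - jnp.tanh(x) ** 2

max_err = float(jnp.max(jnp.abs(grads - expected)))
print(jax.default_backend(), max_err)
```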
Requirements
| Requirement | Version |
|---|---|
| macOS | 13.0+ (Ventura or later) |
| Hardware | Apple Silicon (M1 / M2 / M3 / M4 / M5) |
| Python | 3.11, 3.12, or 3.13 |
| JAX | 0.9.x |
| jaxlib | 0.9.x |
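The platform rows of the table can be checked before installing. The helper below is a hypothetical sketch using only the standard library; `check_host` and `parse_version` are names invented for this example and are not part of jax-metallib.

```python
import platform
import sys

SUPPORTED_PYTHON = {(3, 11), (3, 12), (3, 13)}
MIN_MACOS = (13, 0)


def parse_version(text):
    """Turn '14.5' or '13.0.1' into a comparable tuple of ints."""
    return tuple(int(part) for part in text.split(".") if part.isdigit())


def check_host(system, machine, macos_version, python_version):
    """Return a list of problems; an empty list means the host qualifies."""
    problems = []
    if system != "Darwin":
        problems.append(f"needs macOS, found {system}")
    elif parse_version(macos_version) < MIN_MACOS:
        problems.append(f"needs macOS 13.0+, found {macos_version}")
    if machine != "arm64":
        problems.append(f"needs Apple Silicon (arm64), found {machine}")
    if tuple(python_version[:2]) not in SUPPORTED_PYTHON:
        problems.append("needs Python 3.11, 3.12, or 3.13")
    return problems


# Check the machine this script is running on.
issues = check_host(
    platform.system(), platform.machine(), platform.mac_ver()[0], sys.version_info
)
print("OK" if not issues else "\n".join(issues))
```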
Build from Source
Building from source compiles the native Metal plugin (~6 MB shared library). The first build automatically bootstraps LLVM/MLIR and StableHLO dependencies, which takes about 30 minutes.
brew install cmake ninja
git clone https://github.com/erfanzar/jax-metallib.git
cd jax-metallib
uv sync --all-groups
uv pip install -e .
To skip the automatic dependency bootstrap (if you manage LLVM/MLIR yourself):
CMAKE_ARGS="-DJAX_METALLIB_AUTO_SETUP_DEPS=OFF" uv pip install -e .
Native Dependencies
The bootstrap script (scripts/setup_deps.sh) fetches and builds the following into ~/.local/jax-metallib-deps:
| Dependency | Version | Purpose |
|---|---|---|
| LLVM + MLIR | XLA pin bb760b0 | MLIR infrastructure, StableHLO dialect support |
| StableHLO | 127d2f2 | IR parsing, serialization, dialect definitions |
| Abseil | 20250127.0 | C++ utilities (strings, status, synchronization) |
| Protobuf | 29.3 | Device assignment protocol buffer serialization |
Supported Operations
122 operations are registered across StableHLO, CHLO, and MHLO dialects:
Unary operations (42 ops)
abs cbrt ceil cosine count_leading_zeros exponential exponential_minus_one erf floor imag is_finite log log_plus_one logistic negate real round_nearest_even rsqrt sign sine sqrt tan tanh
CHLO: acos acosh asin asinh atanh bessel_i1e conj cosh digamma erf_inv erfc is_inf is_neg_inf is_pos_inf lgamma sinh square
Binary operations (15 ops)
add atan2 clamp compare divide dot dot_general maximum minimum multiply power remainder select subtract
CHLO: next_after
Shape & indexing (18 ops)
bitcast_convert broadcast broadcast_in_dim concatenate convert dynamic_broadcast_in_dim dynamic_reshape dynamic_slice dynamic_update_slice gather get_dimension_size pad reshape reverse scatter set_dimension_size slice transpose
Reductions (7 ops)
reduce (sum, product, max, min, and, or, argmax, argmin) reduce_window select_and_scatter batch_norm_inference batch_norm_training batch_norm_grad return
Convolution
convolution — full conv_general_dilated with arbitrary padding, dilation, strides, feature grouping, and batch grouping
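As an illustration of the options named above, here is a sketch of a strided, grouped convolution written with `jax.lax.conv_general_dilated`. The shapes, stride, and group count are arbitrary example values, and the snippet runs on any JAX backend.

```python
import jax
import jax.numpy as jnp

# Input: batch 1, 4 channels, 8x8 spatial, in NCHW layout.
lhs = jnp.ones((1, 4, 8, 8), dtype=jnp.float32)
# Kernel: 4 output channels, 2 input channels per group, 3x3 window (OIHW).
rhs = jnp.ones((4, 2, 3, 3), dtype=jnp.float32)

out = jax.lax.conv_general_dilated(
    lhs,
    rhs,
    window_strides=(2, 2),
    padding="SAME",
    rhs_dilation=(1, 1),
    dimension_numbers=("NCHW", "OIHW", "NCHW"),
    feature_group_count=2,  # 4 input channels split into 2 groups of 2
)
print(out.shape)
```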
Linear algebra
cholesky (native MPS kernel) triangular_solve (native MPS kernel)
Other categories
| Category | Operations |
|---|---|
| Bitwise | and or xor not shift_left shift_right_logical shift_right_arithmetic popcnt |
| FFT | fft (FFT, RFFT, IFFT, IRFFT) |
| Sort | sort top_k |
| Random | rng rng_bit_generator (Threefry / Philox) |
| Tensor creation | constant iota |
| Control flow | while case (if/else) custom_call |
| Collective | all_reduce all_gather reduce_scatter collective_permute |
| Higher-order | map complex |
Encountering an unsupported op prints a diagnostic with a direct link to file a feature request.
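Before running a program, it can help to list the ops its StableHLO text contains and compare them with a set of known names. The sketch below is an assumption-laden illustration: `KNOWN` is a tiny hand-picked subset of the ops listed above (not the plugin's registry), `ops_in_module` and `ops_outside` are hypothetical helpers, and the sample module contains a deliberately fictional op.

```python
import re

# Small illustrative subset of the ops listed above; not the real registry.
KNOWN = {"stablehlo.tanh", "stablehlo.add", "stablehlo.dot_general", "chlo.erf_inv"}

# Matches dialect-qualified op names such as stablehlo.tanh or chlo.erf_inv.
OP_PATTERN = re.compile(r"\b(stablehlo|chlo|mhlo)\.([a-z_0-9]+)")


def ops_in_module(mlir_text):
    """Return the sorted set of dialect-qualified op names found in MLIR text."""
    return sorted({f"{dialect}.{name}" for dialect, name in OP_PATTERN.findall(mlir_text)})


def ops_outside(mlir_text, known=KNOWN):
    """Return the ops in the text that are not in the known set."""
    return [op for op in ops_in_module(mlir_text) if op not in known]


# Hand-written sample; made_up_op is fictional and exists only for the demo.
SAMPLE = """
func.func public @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %0 = stablehlo.tanh %arg0 : tensor<4xf32>
  %1 = stablehlo.add %0, %arg0 : tensor<4xf32>
  %2 = stablehlo.made_up_op %1 : tensor<4xf32>
  return %2 : tensor<4xf32>
}
"""

print(ops_in_module(SAMPLE))
print(ops_outside(SAMPLE))
```

For a real program, the input text can be obtained with `jax.jit(fn).lower(*args).as_text()`.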
Architecture
JAX Python program
│
▼
StableHLO IR (MLIR bytecode) ← jax.jit compiles Python to StableHLO
│
▼
PJRT C API layer ← pjrt_api.cc exposes client/device/buffer/exec
│
▼
StableHLO Parser ← deserializes + inlines MLIR modules
│
▼
Execution Plan Builder ← walks ops, groups into MPSGraph segments
│ + native MPS kernel steps
├──► MPSGraph segments ← op handlers build compute graphs
│ │
│ ▼
│ Metal command buffer ← compiled & dispatched to GPU
│
└──► Native MPS kernels ← direct kernel dispatch (Cholesky, etc.)
│
▼
Device memory (MTLBuffer) ← results flow back to JAX as DeviceArrays
The plugin interleaves three execution paths within a single program:
- Graph execution — Consecutive ops are batched into an MPSGraph, compiled to a Metal compute pipeline, and dispatched as a single GPU command. This is the primary path for most operations.
- Native execution — Performance-critical operations (e.g., Cholesky decomposition via MPSMatrixDecompositionCholesky) bypass the graph compiler and dispatch MPS native kernels directly on MTLBuffer objects.
- JIT fusion — Chains of elementwise operations are detected and fused into custom Metal Shading Language (MSL) kernels at runtime, reducing kernel launch overhead and improving memory bandwidth utilization.
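The way graph segments and native steps alternate can be pictured with a toy plan builder. This is a conceptual sketch in Python, not the plugin's Objective-C++ implementation: `Step`, `build_plan`, and the contents of `NATIVE_OPS` are assumptions made for illustration.

```python
from dataclasses import dataclass, field

# Assumption for illustration: only these ops take the native-kernel path.
NATIVE_OPS = {"cholesky", "triangular_solve"}


@dataclass
class Step:
    kind: str  # "graph" or "native"
    ops: list = field(default_factory=list)


def build_plan(op_names):
    """Group consecutive graph ops into one segment; each native op is its own step."""
    plan = []
    for name in op_names:
        if name in NATIVE_OPS:
            plan.append(Step("native", [name]))
        elif plan and plan[-1].kind == "graph":
            plan[-1].ops.append(name)  # extend the open graph segment
        else:
            plan.append(Step("graph", [name]))  # start a new graph segment
    return plan


program = ["dot_general", "add", "cholesky", "tanh", "multiply", "triangular_solve"]
plan = build_plan(program)
for step in plan:
    print(step.kind, step.ops)
```

In this toy model a native op always closes the current graph segment, so the six ops above become four steps.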
Configuration
Environment Variables
| Variable | Default | Description |
|---|---|---|
| JAX_PLATFORMS | — | Set to mps to select the Metal backend |
| JAX_METALLIB_LIBRARY_PATH | auto | Override path to libpjrt_plugin_silicon.dylib |
| MPS_LOG_LEVEL | 1 | Logging verbosity: 0 error, 1 warn, 2 info, 3 debug |
| JAX_TEST_MODE | compare | Test mode: compare (CPU vs MPS), mps, or cpu |
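These variables can also be assembled from Python, for example to launch a child process with a known configuration. The helper below is a hypothetical sketch (`metal_env` is not part of the package); it only builds a dictionary from the variable names and values documented in the table.

```python
import os

LOG_LEVELS = {0: "error", 1: "warn", 2: "info", 3: "debug"}


def metal_env(log_level=1, library_path=None, base=None):
    """Return a copy of `base` (default: os.environ) configured for the Metal backend."""
    if log_level not in LOG_LEVELS:
        raise ValueError(f"MPS_LOG_LEVEL must be one of {sorted(LOG_LEVELS)}")
    env = dict(os.environ if base is None else base)
    env["JAX_PLATFORMS"] = "mps"
    env["MPS_LOG_LEVEL"] = str(log_level)
    if library_path is not None:
        env["JAX_METALLIB_LIBRARY_PATH"] = str(library_path)
    return env


# Start from an empty base so the printed result shows only what was set.
env = metal_env(log_level=3, base={})
print(env)
```

The result can be passed as `env=` to `subprocess.run`. To configure the current process instead, apply it to `os.environ` before importing jax.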
Library Discovery
The plugin searches for the native library in this order:
1. JAX_METALLIB_LIBRARY_PATH environment variable
2. Package directory (editable install)
3. <package>/lib/ (wheel install)
4. build/*/lib/ (CMake build directory)
5. /usr/local/lib/, /opt/homebrew/lib/
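The search order above can be sketched as a generator of candidate paths. This is an illustrative reconstruction, not the package's actual discovery code: `candidate_paths`, `find_library`, and the `repo_root` parameter are assumptions, and the library file name is taken from the environment variable table.

```python
import os
from pathlib import Path

LIB_NAME = "libpjrt_plugin_silicon.dylib"


def candidate_paths(package_dir, repo_root, env=None):
    """Yield candidate library paths in the documented priority order."""
    env = os.environ if env is None else env
    override = env.get("JAX_METALLIB_LIBRARY_PATH")
    if override:
        yield Path(override)  # 1. explicit override
    package_dir, repo_root = Path(package_dir), Path(repo_root)
    yield package_dir / LIB_NAME  # 2. editable install
    yield package_dir / "lib" / LIB_NAME  # 3. wheel install
    yield from sorted(repo_root.glob(f"build/*/lib/{LIB_NAME}"))  # 4. CMake build
    yield Path("/usr/local/lib") / LIB_NAME  # 5. system locations
    yield Path("/opt/homebrew/lib") / LIB_NAME


def find_library(package_dir, repo_root, env=None):
    """Return the first candidate that exists on disk, or None."""
    for path in candidate_paths(package_dir, repo_root, env):
        if path.is_file():
            return path
    return None
```

Passing `env={}` disables the override, which makes the remaining order easy to test in isolation.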
Testing
The test suite validates numerical correctness by comparing CPU and MPS results with configurable tolerances.
uv run pytest
JAX_TEST_MODE=mps uv run pytest
uv run pytest -k "unary"
uv run pytest -k "linalg"
uv run pytest -k "flax"
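The compare mode boils down to checking a reference output against a candidate output within relative and absolute tolerances. The sketch below shows that idea with plain Python lists; the tolerance values in `TOLERANCES` and the `mismatches` helper are hypothetical and are not taken from the test suite.

```python
import math

# Hypothetical (rtol, atol) pairs per dtype; the suite's real values may differ.
TOLERANCES = {"float32": (1e-5, 1e-6), "float16": (1e-2, 1e-3)}


def mismatches(reference, candidate, dtype="float32"):
    """Return indices where candidate differs from reference beyond tolerance."""
    if len(reference) != len(candidate):
        raise ValueError("outputs differ in length")
    rtol, atol = TOLERANCES[dtype]
    bad = []
    for i, (ref, got) in enumerate(zip(reference, candidate)):
        if math.isnan(ref) and math.isnan(got):
            continue  # treat matching NaNs as equal
        if not math.isclose(ref, got, rel_tol=rtol, abs_tol=atol):
            bad.append(i)
    return bad


cpu_out = [1.0, 2.0, 3.0]
mps_out = [1.0, 2.000001, 3.1]
print(mismatches(cpu_out, mps_out))
```

Looser tolerances for float16 reflect its smaller significand; a difference that fails under float32 settings may be acceptable there.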
Test Coverage
| Category | What's tested |
|---|---|
| Value correctness | CPU vs MPS output comparison for all 122 ops |
| Gradient accuracy | jax.grad / jax.value_and_grad for differentiable ops |
| Edge cases | float16 precision, int64 large values, complex numbers |
| Regressions | Catastrophic cancellation (log1p), erf_inv range accuracy |
| Integration | Flax NNX (Linear, Conv, LayerNorm, MultiHeadAttention) |
| Integration | NumPyro probabilistic programming models |
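Two of the rows above refer to well-known floating-point pitfalls, which can be reproduced with the standard library alone. This is a worked illustration in float64 on the host, independent of the plugin: the first part shows the cancellation that log1p avoids, the second shows why int64 values above 2^53 do not survive a round trip through floating point.

```python
import math

x = 1e-18

# Naive form: 1.0 + 1e-18 rounds to exactly 1.0 in float64, so the log is 0.
naive = math.log(1.0 + x)

# log1p evaluates log(1 + x) without forming 1 + x, so the tiny value survives.
stable = math.log1p(x)

print(naive, stable)

# float64 has a 53-bit significand, so 2**53 + 1 is not representable
# and rounds to a neighbouring integer.
big = 2**53 + 1
round_trip_ok = int(float(big)) == big
print(round_trip_ok)
```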
Quality Gate
Pre-commit hooks enforce the full quality gate on every commit (MPS is unavailable in GitHub Actions):
- clang-format — C/C++/ObjC++ formatting (LLVM style, 110 col)
- ruff — Python formatting and linting
- build — full native library rebuild
- clang-tidy — C++ static analysis (with caching via ctcache)
- pytest — full test suite with op coverage enforcement
Benchmarks
uv run python -m benchmarks.bench list
JAX_PLATFORMS=mps uv run python -m benchmarks.bench run --case 'anchor\..*' --platform mps
uv run python -m benchmarks.bench run --case '.*' --json-out results.jsonl
The benchmark suite includes per-op micro-benchmarks, representative anchor workloads, and competitive benchmarks against other backends.
Repository Layout
src/
jax_plugins/silicon/ Python entrypoint — plugin registration & library discovery
pjrt_plugin/
api/ PJRT C API layer (8 implementation files)
pjrt_api.cc Function pointer table (main entry point)
pjrt_client.cc Client: compile, buffer creation, platform info
pjrt_buffer.cc Buffer: host↔device transfer, copy, clone
pjrt_executable.cc Executable: execute, serialize, output metadata
pjrt_device.cc Device: attributes, memory, description
pjrt_event.cc Event: async completion, create, set
pjrt_memory.cc Memory: kind, addressable devices
pjrt_topology.cc Topology: device descriptions, platform
core/ Backend core (Objective-C++ / Metal)
mps_client.h/.mm Metal device & command queue management
mps_device.h/.mm GPU device abstraction
mps_buffer.h/.mm MTLBuffer wrapper — host copy, clone, blit
mps_executable.h/.mm Execution plan builder & runner
stablehlo_parser.h/.mm MLIR StableHLO deserializer
type_utils.h/.mm MLIR ↔ MPS type conversions
completion_event.h Thread-safe async completion primitives
pjrt_types.h PJRT opaque wrapper structs
logging.h Leveled logging macros
ops/ StableHLO op handlers (122 registrations)
registry.h Op registry, handler types, macros
unary_ops.mm 42 unary operations
binary_ops.mm 15 binary operations
shape_ops.mm 18 shape & indexing operations
reduction_ops.mm 7 reduction operations
convolution_ops.mm General dilated convolution
linalg_ops.mm Cholesky, triangular solve (native MPS)
bitwise_ops.mm 8 bitwise operations
sort_ops.mm Sort, top-k
fft_ops.mm FFT / RFFT / IFFT / IRFFT
control_flow_ops.mm While, case, custom_call
random_ops.mm Threefry / Philox RNG
tensor_creation_ops.mm Constant, iota
collective_ops.mm Single-device collective no-ops
higher_order_ops.mm Map
runtime/ JIT Metal kernel engine
metal_kernels.h/.mm MSL kernel source cache & pipeline compilation
proto/
device_assignment.proto XLA DeviceAssignmentProto definition
tests/
test_ops.py Parametrized test suite (value + gradient)
test_int64_constant_splats.py Regression: int64 precision above 2^53
configs/ Per-category test configurations (15 modules)
benchmarks/
bench.py Benchmark harness (list/run, JSONL output)
bottlenecks.py Per-op micro-benchmarks
vs_jax_mps.py Competitive benchmarks
anchors.py Representative anchor workloads
scripts/
setup_deps.sh One-time native dependency bootstrap
Contributing
See CONTRIBUTING.md for the full guide. The short version:
brew install cmake ninja
./scripts/setup_deps.sh
uv sync --all-groups
uv pip install -e .
pre-commit install
uv pip install -e .
uv run pytest
Acknowledgements
This project draws significant inspiration from MLX by Apple Machine Learning Research. MLX's approach to leveraging Metal and unified memory on Apple Silicon was a major influence on the design and direction of jax-metallib.
License
Copyright 2026 Erfan Zar (@erfanzar). Apache-2.0 — see LICENSE.