JAX backend for Apple Metal Performance Shaders (MPS)
applejax
A JAX backend for Apple Metal Performance Shaders (MPS), enabling GPU-accelerated JAX computations on Apple Silicon.
Fork of tillahoffmann/jax-mps with full linear algebra, complex number support, comprehensive scatter/gather handling, and 2000+ tests.
Quick Start
pip install applejax
The plugin registers itself with JAX automatically. Set JAX_PLATFORMS=mps to select it explicitly.
Requires macOS on Apple Silicon, Python >= 3.11, and jaxlib 0.9.x.
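To confirm which backend JAX actually picked up, a quick check along the following lines may help. This is a minimal sketch: it assumes the plugin reports the platform name `mps` (matching the `JAX_PLATFORMS` value above), and it runs on whatever backend JAX selects if the plugin is not installed.

```python
import jax
import jax.numpy as jnp

# Report the backend JAX selected. With applejax installed this is expected
# (assumption) to be "mps"; without it JAX falls back to e.g. "cpu".
backend = jax.default_backend()
print("default backend:", backend)
print("devices:", jax.devices())

# Run a tiny computation to confirm the backend executes work.
x = jnp.arange(8, dtype=jnp.float32)
sum_of_squares = jnp.sum(x * x)
print("sum of squares:", float(sum_of_squares))  # 0 + 1 + 4 + ... + 49 = 140.0
```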
Performance
applejax achieves a modest 3x speed-up over the CPU backend when training a simple ResNet18 model on CIFAR-10 using an M4 MacBook Air.
$ JAX_PLATFORMS=cpu uv run examples/resnet/main.py --steps=30
Time per step (second half): 3.041
$ JAX_PLATFORMS=mps uv run examples/resnet/main.py --steps=30
Time per step (second half): 0.991
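The "second half" figure suggests that early steps, which include JIT compilation, are excluded from the average. The sketch below shows one way such a measurement could be taken and how the ratio of the two numbers above works out. The helper names `time_per_step_second_half` and `speedup` are hypothetical, and the actual `examples/resnet/main.py` may measure differently.

```python
import time
from typing import Callable


def time_per_step_second_half(step_fn: Callable[[int], object], steps: int) -> float:
    """Average wall-clock seconds per step over the second half of a run."""
    if steps < 2:
        raise ValueError("need at least 2 steps")
    durations = []
    for i in range(steps):
        start = time.perf_counter()
        # For JAX, step_fn should block on its result (e.g. via
        # .block_until_ready()) so asynchronous dispatch is fully timed.
        step_fn(i)
        durations.append(time.perf_counter() - start)
    second_half = durations[steps // 2:]  # skip warm-up / compilation steps
    return sum(second_half) / len(second_half)


def speedup(baseline: float, accelerated: float) -> float:
    """Ratio of baseline time to accelerated time."""
    return baseline / accelerated


# Figures reported above for the CPU and MPS runs.
print(f"{speedup(3.041, 0.991):.2f}x")  # prints "3.07x"
```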
What Works
Most JAX operations are supported, verified across 2000+ tests covering all major categories (exceptions are listed under Known Limitations):
| Category | Status | Notes |
|---|---|---|
| Element-wise math (unary/binary) | Full | sin, cos, exp, log, erf, gelu, etc. |
| Reductions | Full | sum, prod, max, min, argmax, cumsum, cummax, logsumexp |
| Matmul / dot products | Full | float16, bfloat16, float32, int, complex64 |
| Convolution | Full | 1D, 2D, depthwise, transposed, dilated |
| Pooling | Full | max, avg, min pool with gradients |
| FFT | Full | fft, ifft, rfft, irfft, fft2, ifft2 |
| Sorting | Full | sort, argsort, top_k, unique, searchsorted |
| Shape ops | Full | reshape, transpose, pad, gather, scatter, concatenate |
| Bitwise ops | Full | and, or, xor, not, shifts, population_count, clz |
| Random | Full | normal, uniform, bernoulli, categorical, poisson, gamma, beta, etc. |
| Type conversions | Full | float16/bfloat16/float32/int8-64/bool/complex64, reduce_precision |
| Control flow | Full | cond, switch, while_loop, fori_loop, scan, associative_scan |
| Autodiff | Full | grad, jacobian, hessian, HVP, checkpoint, custom_jvp/vjp |
| Transforms | Full | jit, vmap, pmap (single device) |
| Linear algebra | Full | See below |
| Complex numbers | Nearly full | Arithmetic, matmul, FFT, all linalg. No complex sort/conv |
| scipy.special | Full | erf, gammaln, digamma, betaln, logit, expit, etc. |
| scipy.signal | Full | convolve, correlate, fftconvolve |
| scipy.stats | Full | norm.logpdf, norm.cdf, norm.ppf |
| scipy.ndimage | Full | map_coordinates |
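As an illustration of how the autodiff and transform rows compose, the following sketch computes per-example gradients with `jit`, `vmap`, and `grad`. It uses only standard JAX APIs; the model and data are made up for the example, and the code is backend-agnostic, so it does not by itself demonstrate anything specific to applejax.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Squared error of a linear model on a single example."""
    return 0.5 * (jnp.dot(w, x) - y) ** 2


# grad differentiates w.r.t. w; vmap maps over the batch axis of x and y;
# jit compiles the composed function for the active backend.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, -2.0], dtype=jnp.float32)
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.float32)
ys = jnp.array([0.0, 0.0, 1.0], dtype=jnp.float32)

grads = per_example_grads(w, xs, ys)  # shape (3, 2): one gradient per example
print(grads)
```

Each row of `grads` equals the residual `(w·x - y)` times `x` for that example.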
Linear Algebra
All operations work for both real (float32) and complex (complex64) inputs:
| Operation | Function | Backend |
|---|---|---|
| Solve | jnp.linalg.solve | MPS Graph |
| Inverse | jnp.linalg.inv | MPS Graph |
| Cholesky | jnp.linalg.cholesky | MPS Graph (real), Accelerate cpotrf_ (complex) |
| Triangular solve | scipy.linalg.solve_triangular | MPS Graph (real), Accelerate ctrsm_ (complex) |
| QR | jnp.linalg.qr | Accelerate sgeqrf_/cgeqrf_ |
| SVD | jnp.linalg.svd | Accelerate sgesdd_/cgesdd_ |
| Eigendecomposition (symmetric) | jnp.linalg.eigh | Accelerate ssyevd_/cheevd_ |
| Eigendecomposition (general) | jnp.linalg.eig | Accelerate sgeev_/cgeev_ |
| Schur | scipy.linalg.schur | Accelerate sgees_/cgees_ |
| Matrix square root | scipy.linalg.sqrtm | Via Schur |
| Matrix exponential | scipy.linalg.expm | Via Schur + solve |
| LU | scipy.linalg.lu | Via JAX primitives |
| Determinant, norm, cond, rank, pinv, lstsq | All | Via SVD/QR/solve |
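A short float32 example touching a few of the operations in the table is sketched below. The matrix and right-hand side are arbitrary illustrative values, and the residual checks are only a sanity test, not a statement about the accuracy applejax achieves.

```python
import jax.numpy as jnp

# Build a small symmetric positive-definite matrix in float32.
m = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0]], dtype=jnp.float32)
a = m @ m.T + 3.0 * jnp.eye(3, dtype=jnp.float32)
b = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)

x = jnp.linalg.solve(a, b)             # listed above as MPS Graph
chol = jnp.linalg.cholesky(a)          # lower-triangular factor
eigvals, eigvecs = jnp.linalg.eigh(a)  # listed above as Accelerate-backed

# Residuals should be small relative to float32 precision.
solve_residual = jnp.max(jnp.abs(a @ x - b))
chol_residual = jnp.max(jnp.abs(chol @ chol.T - a))
print(float(solve_residual), float(chol_residual), eigvals)
```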
ML Framework Compatibility
Tested successfully with:
- Flax NNX — training loops with optimizers
- NumPyro — MCMC inference (NUTS sampler)
- Optax — all standard optimizers
- Equinox — neural network modules
Transformer components work end-to-end: multi-head attention, RoPE, RMSNorm, SwiGLU, causal masking.
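For orientation, two of the listed components (RMSNorm and causal masking) can be written in a few lines of plain JAX. This is a simplified single-head sketch with `q = k = v`, intended only to show the kind of code involved; it is not taken from the project's tests.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    """Scale x by the reciprocal root-mean-square of its last axis."""
    mean_sq = jnp.mean(x * x, axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(mean_sq + eps) * weight


def causal_self_attention(x):
    """Single-head attention with q = k = v = x and a causal mask."""
    seq_len, dim = x.shape
    scores = (x @ x.T) / (dim ** 0.5)
    # Position i may attend only to positions j <= i.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    # Use a large finite negative value rather than -inf for masked entries.
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    return jax.nn.softmax(scores, axis=-1) @ x


x = jax.random.normal(jax.random.PRNGKey(0), (4, 8), dtype=jnp.float32)
h = rms_norm(x, jnp.ones(8, dtype=jnp.float32))
out = causal_self_attention(h)
print(out.shape)
```

Because the first position can only attend to itself, `out[0]` equals `h[0]`, which makes a convenient check that the mask is applied.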
Known Limitations
These are Metal/MPS hardware constraints, not bugs in applejax:
| Limitation | Impact | Workaround |
|---|---|---|
| No float64 | Metal GPUs only support 32-bit floats | Use float32 (default). jax.config.update("jax_enable_x64", True) will not work. |
| No complex sort | jnp.sort on complex arrays crashes MPS | Sort real/imag parts separately |
| No complex convolution | MPS conv ops don't support complex types | Decompose into real/imag convolutions manually |
| No jax.debug.print | Debug printing inside JIT not supported | Use jax.debug.callback or print outside JIT |
| Linalg inside control flow | QR, SVD, eigh, eig inside scan/fori_loop/while_loop crash (Accelerate-backed ops run on CPU, incompatible with MPS Graph control flow) | Restructure code to call these ops outside control flow |
| No buffer donation | Memory optimization hint is ignored (warning only) | No impact on correctness, minor memory overhead |
| scipy.linalg.polar(method='qdwh') crashes | QDWH algorithm promotes to float64 internally | Use polar(method='svd') instead |
| Zero-size arrays | MPS doesn't support empty tensors | Avoid zero-dimension operations |
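Three of the workarounds above can be sketched as follows. The helper names (`complex_convolve`, `complex_sort`, `project_sequence`) are hypothetical and not part of applejax. The sketch assumes that real-valued `jnp.convolve`, `jnp.lexsort` on real keys, and indexing into complex arrays are all supported on the MPS backend, which has not been verified here.

```python
import jax
import jax.numpy as jnp


def complex_convolve(a, b):
    """1D full convolution of complex arrays using four real convolutions.

    (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
    """
    ar, ai = jnp.real(a), jnp.imag(a)
    br, bi = jnp.real(b), jnp.imag(b)
    real = jnp.convolve(ar, br) - jnp.convolve(ai, bi)
    imag = jnp.convolve(ar, bi) + jnp.convolve(ai, br)
    return jax.lax.complex(real, imag)


def complex_sort(x):
    """Sort by real part, breaking ties by imaginary part (NumPy's ordering)."""
    order = jnp.lexsort((jnp.imag(x), jnp.real(x)))  # last key is primary
    return x[order]


def project_sequence(a, xs):
    """Project each row of xs onto the column space of a.

    QR is called once, outside the scan; the scan body uses only matmuls.
    """
    q, _ = jnp.linalg.qr(a)  # Accelerate-backed op kept out of control flow

    def body(total, x):
        p = q @ (q.T @ x)
        return total + jnp.sum(p * p), p

    return jax.lax.scan(body, jnp.zeros((), dtype=jnp.float32), xs)


a = jnp.array([1 + 1j, 2], dtype=jnp.complex64)
b = jnp.array([1j, 1], dtype=jnp.complex64)
print(complex_convolve(a, b))
print(complex_sort(jnp.array([2 + 1j, 1 + 2j, 1 + 1j], dtype=jnp.complex64)))
```

The same pattern as `project_sequence` applies to SVD, eigh, and eig: compute the factorization before entering `scan`, `fori_loop`, or `while_loop`, and close over the result.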
Architecture
This project implements a PJRT plugin to offload evaluation of JAX expressions to a Metal Performance Shaders Graph. The evaluation proceeds in several stages:
- JAX lowers the program to StableHLO, a set of high-level operations for ML.
- The plugin parses the StableHLO representation and builds the corresponding MPS graph. The graph is cached to avoid re-construction on repeated invocations.
- The MPS graph is executed on the GPU. Operations not natively supported by MPS (e.g., linear algebra decompositions) run on CPU via Apple's Accelerate framework using a "native handler" mechanism.
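The first stage can be inspected from Python with JAX's ahead-of-time lowering API. The sketch below prints the StableHLO text for a small function; this is the kind of program a PJRT plugin receives, although the plugin itself consumes it through the PJRT C API rather than as printed text.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) + 1.0


x = jnp.ones((4,), dtype=jnp.float32)

# Lower without executing, then render the module as StableHLO MLIR text.
lowered = jax.jit(f).lower(x)
stablehlo_text = lowered.as_text()
print(stablehlo_text)
```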
Operation Implementations
| Layer | Examples | Count |
|---|---|---|
| StableHLO graph ops | add, matmul, conv, reduce, sort, FFT, gather, scatter | 71 |
| CHLO graph ops | erf, top_k, acos, sinh, erf_inv, next_after | 12 |
| Native handlers (CPU via Accelerate) | cholesky, triangular_solve, eigh, SVD, eig, Schur | 6 |
| Python lowering rules | eigh, svd, eig, schur → custom_call → native handler | 4 |
Building from Source
- Install build tools and build LLVM/MLIR & StableHLO (one-time, ~30 minutes):
brew install cmake ninja
./scripts/setup_deps.sh
- Build and install:
uv pip install -e .
- Install dev dependencies (test runner, linters, ML frameworks used in examples) and run tests:
uv sync --all-groups
uv run pytest
Version Pinning
applejax is built against the StableHLO bytecode format matching jaxlib 0.9.x. The setup_deps.sh script pins LLVM and StableHLO to specific commits from the XLA version used by jaxlib 0.9.0.
Runtime compatibility: Any jaxlib 0.9.x release should work with a built binary — the bytecode format is stable within a minor version series. The plugin will warn (not error) if the minor version doesn't match.
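The warn-but-don't-fail behavior described above could look roughly like the following. This is an illustrative sketch only: `parse_major_minor` and `check_jaxlib_compat` are hypothetical names, and the plugin's real check may be implemented differently (for example, in C++).

```python
import warnings


def parse_major_minor(version: str) -> tuple[int, int]:
    """Extract (major, minor) from a version string such as '0.9.2'."""
    major, minor, *_ = version.split(".")
    return int(major), int(minor)


def check_jaxlib_compat(built_against: str, installed: str) -> bool:
    """Return True if the minor series matches; otherwise warn and return False."""
    built = parse_major_minor(built_against)
    found = parse_major_minor(installed)
    if built != found:
        warnings.warn(
            f"built against jaxlib {built[0]}.{built[1]}.x but found {installed}; "
            "the StableHLO bytecode format may differ",
            RuntimeWarning,
            stacklevel=2,
        )
        return False
    return True


print(check_jaxlib_compat("0.9.0", "0.9.2"))  # same minor series
```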
Updating for a new jaxlib release: Trace the dependency chain:
# 1. Find XLA commit used by jaxlib
curl -s https://raw.githubusercontent.com/jax-ml/jax/jax-v0.9.0/third_party/xla/revision.bzl
# 2. Find LLVM and StableHLO commits used by that XLA version
curl -s https://raw.githubusercontent.com/openxla/xla/<XLA_COMMIT>/third_party/llvm/workspace.bzl
curl -s https://raw.githubusercontent.com/openxla/xla/<XLA_COMMIT>/third_party/stablehlo/workspace.bzl
Then update STABLEHLO_COMMIT and LLVM_COMMIT_OVERRIDE in setup_deps.sh.
Project Structure
applejax/
├── CMakeLists.txt
├── src/
│ ├── jax_plugins/mps/ # Python plugin: registration + lowering rules
│ ├── pjrt_plugin/ # C++ PJRT implementation
│ │ ├── pjrt_api.cc # PJRT C API entry point
│ │ ├── mps_client.h/mm # Metal client management
│ │ ├── mps_executable.h/mm # StableHLO compilation & execution
│ │ └── ops/ # Operation implementations
│ │ ├── unary_ops.mm # Element-wise unary operations
│ │ ├── binary_ops.mm # Binary operations, dot products, matmul
│ │ ├── bitwise_ops.mm # Bitwise operations
│ │ ├── shape_ops.mm # Gather, scatter, reshape, pad, etc.
│ │ ├── reduction_ops.mm # Reduce, reduce_window, scan
│ │ ├── linalg_ops.mm # Cholesky, QR, SVD, eigh, eig, Schur
│ │ ├── convolution_ops.mm # Convolution
│ │ ├── control_flow_ops.mm # cond, while, scan
│ │ ├── fft_ops.mm # FFT operations
│ │ ├── sort_ops.mm # Sort and top-k
│ │ ├── tensor_creation_ops.mm # Constants, iota
│ │ └── registry.h # Op registration macros
│ └── proto/ # Protobuf definitions
├── tests/
│ ├── test_ops.py # Main test file (parameterized)
│ └── configs/ # Test configurations by category
└── mps_ops/ # Reference docs for MPS Graph methods
Benchmarks
uv run pytest -m benchmark --benchmark-only