JAX backend for Apple Metal Performance Shaders (MPS)
jax-mps 
A JAX backend for Apple Metal Performance Shaders (MPS), enabling GPU-accelerated JAX computations on Apple Silicon.
Example
jax-mps achieves a modest speed-up of about 3x over the CPU backend when training a simple ResNet18 model on CIFAR-10 on an M4 MacBook Air (roughly 3.0 s vs. 1.0 s per step):
```console
$ JAX_PLATFORMS=cpu uv run examples/resnet/main.py --steps=30
loss = 0.029: 100%|██████████| 30/30 [01:29<00:00, 2.99s/it]
Final training loss: 0.029
Time per step (second half): 3.041
$ JAX_PLATFORMS=mps uv run examples/resnet/main.py --steps=30
WARNING:2026-01-26 17:32:53,989:jax._src.xla_bridge:905: Platform 'mps' is experimental and not all JAX functionality may be correctly supported!
loss = 0.028: 100%|██████████| 30/30 [00:30<00:00, 1.03s/it]
Final training loss: 0.028
Time per step (second half): 0.991
```
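The benchmark reports the time per step over the second half of training, which keeps one-off costs such as JIT compilation and graph construction out of the average. A minimal sketch of that kind of measurement follows; the helper name time_per_step_second_half is hypothetical and this is not the code in examples/resnet/main.py. Because JAX dispatches work asynchronously, a real JAX step function should call block_until_ready() on its outputs, otherwise only the dispatch time is measured.

```python
import time
from typing import Callable


def time_per_step_second_half(step: Callable[[], object], n_steps: int) -> float:
    """Run `step` n_steps times and return the mean wall time of the second half.

    Hypothetical helper. For JAX, `step` should block on its results (for
    example with .block_until_ready()) so that asynchronous dispatch does not
    hide the actual compute time.
    """
    if n_steps < 1:
        raise ValueError("n_steps must be at least 1")
    durations = []
    for _ in range(n_steps):
        start = time.perf_counter()
        step()
        durations.append(time.perf_counter() - start)
    # Skip the first half, which absorbs compilation and warm-up costs.
    second_half = durations[n_steps // 2:]
    return sum(second_half) / len(second_half)
```

Applied to the numbers above, the speed-up is 3.041 / 0.991 ≈ 3.07.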
Architecture
This project implements a PJRT plugin that offloads evaluation of JAX programs to Apple's Metal Performance Shaders Graph (MPSGraph). Evaluation proceeds in three stages:
- The JAX program is lowered to StableHLO, a set of high-level operations for machine learning applications.
- The plugin parses the StableHLO representation of the program and builds the corresponding MPS graph. The graph is cached so that it is not rebuilt when the same program is invoked again, e.g., on repeated training steps.
- The MPS graph is executed, using native MPS operations where possible, and the results are returned to the caller.
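The first two stages can be illustrated from Python. Lowering uses JAX's public jax.jit(...).lower(...).as_text() API, which prints StableHLO on any backend and shows the kind of program text a PJRT plugin receives. The caching half is a hypothetical sketch: build_graph and get_graph are made-up names, keying the cache on the program text is an assumption, and "building" a graph here only collects op names instead of creating an MPSGraph.

```python
import re

import jax
import jax.numpy as jnp


# Stage 1: lower a jitted function to StableHLO (public JAX API, any backend).
def f(x, y):
    return jnp.tanh(jnp.dot(x, y) + 1.0)


x = jnp.ones((2, 3), dtype=jnp.float32)
y = jnp.ones((3, 4), dtype=jnp.float32)
stablehlo_text = jax.jit(f).lower(x, y).as_text()

# Stage 2 (hypothetical): cache "graphs" keyed by the program text so that a
# repeated program, such as the same training step, is not rebuilt.
_graph_cache: dict[str, list[str]] = {}
build_count = {"n": 0}


def build_graph(program: str) -> list[str]:
    """Hypothetical stand-in for MPSGraph construction: it only lists op names."""
    build_count["n"] += 1
    return re.findall(r"\bstablehlo\.[a-z_]+", program)


def get_graph(program: str) -> list[str]:
    """Return the cached graph for `program`, building it only on first use."""
    if program not in _graph_cache:
        _graph_cache[program] = build_graph(program)
    return _graph_cache[program]


ops = get_graph(stablehlo_text)
print(ops)  # the exact op list depends on the JAX version
```

Calling get_graph a second time with the same text returns the cached object without rebuilding, which mirrors the reuse across training steps described in stage 2.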
Building
- Install the build tools, then build and install LLVM/MLIR and StableHLO. This is a one-time setup that takes about 30 minutes. See the setup_deps.sh script for further options, such as forced re-installation and the installation location. The script pins LLVM and StableHLO to specific commits matching jaxlib 0.9.0 for bytecode compatibility (see the section on Version Pinning for details).
```console
$ brew install cmake ninja
$ ./scripts/setup_deps.sh
```
- Build the plugin and install it as a Python package. This step is fast and MUST be repeated after every change to the C++ files.
```console
$ uv pip install -e .
```
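After installing, a quick sanity check is to confirm that the Python side of the plugin (src/jax_plugins/mps/ in the project layout) is importable. The sketch below uses only the standard library; the function name mps_plugin_installed is hypothetical, and it checks that the package can be found, not that the Metal device works.

```python
import importlib.util


def mps_plugin_installed() -> bool:
    """Return True if the jax_plugins.mps package can be located on sys.path."""
    try:
        return importlib.util.find_spec("jax_plugins.mps") is not None
    except ModuleNotFoundError:
        # The parent jax_plugins namespace package is not installed at all.
        return False


print(mps_plugin_installed())
# To check the device itself, run JAX with the platform selected, e.g.:
#   JAX_PLATFORMS=mps python -c "import jax; print(jax.devices())"
```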
Version Pinning
The script pins LLVM and StableHLO to specific commits matching jaxlib 0.9.0 for bytecode compatibility. To update these versions for a different jaxlib release, trace the dependency chain:
```shell
# 1. Find XLA commit used by jaxlib
curl -s https://raw.githubusercontent.com/jax-ml/jax/jax-v0.9.0/third_party/xla/revision.bzl
# → XLA_COMMIT = "bb760b04..."
# 2. Find LLVM commit used by that XLA version
curl -s https://raw.githubusercontent.com/openxla/xla/<XLA_COMMIT>/third_party/llvm/workspace.bzl
# → LLVM_COMMIT = "f6d0a512..."
# 3. Find StableHLO commit used by that XLA version
curl -s https://raw.githubusercontent.com/openxla/xla/<XLA_COMMIT>/third_party/stablehlo/workspace.bzl
# → STABLEHLO_COMMIT = "127d2f23..."
```
Then update the STABLEHLO_COMMIT and LLVM_COMMIT_OVERRIDE variables in setup_deps.sh.
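The three lookups can also be chained in a small script. This is a sketch under the assumption, taken from the comments in the commands above, that each .bzl file assigns the commit as NAME = "..."; the names resolve_commits, extract and fetch_text are hypothetical and not part of setup_deps.sh. The fetch function is injectable so the parsing can be tested offline.

```python
import re
import urllib.request
from typing import Callable

XLA_REVISION_URL = "https://raw.githubusercontent.com/jax-ml/jax/jax-v{version}/third_party/xla/revision.bzl"
LLVM_WORKSPACE_URL = "https://raw.githubusercontent.com/openxla/xla/{xla}/third_party/llvm/workspace.bzl"
STABLEHLO_WORKSPACE_URL = "https://raw.githubusercontent.com/openxla/xla/{xla}/third_party/stablehlo/workspace.bzl"


def fetch_text(url: str) -> str:
    """Download a file and return it as text."""
    with urllib.request.urlopen(url) as response:
        return response.read().decode("utf-8")


def extract(text: str, name: str) -> str:
    """Return the quoted value assigned to `name` (assumed form: NAME = "...")."""
    match = re.search(rf'^\s*{re.escape(name)}\s*=\s*"([^"]+)"', text, re.MULTILINE)
    if match is None:
        raise ValueError(f"{name} not found")
    return match.group(1)


def resolve_commits(version: str, fetch: Callable[[str], str] = fetch_text) -> dict[str, str]:
    """Follow jaxlib -> XLA -> LLVM/StableHLO and return the three commits."""
    xla = extract(fetch(XLA_REVISION_URL.format(version=version)), "XLA_COMMIT")
    llvm = extract(fetch(LLVM_WORKSPACE_URL.format(xla=xla)), "LLVM_COMMIT")
    stablehlo = extract(fetch(STABLEHLO_WORKSPACE_URL.format(xla=xla)), "STABLEHLO_COMMIT")
    return {"XLA_COMMIT": xla, "LLVM_COMMIT": llvm, "STABLEHLO_COMMIT": stablehlo}
```

For example, resolve_commits("0.9.0") performs the same three lookups as the curl commands; the resulting LLVM and StableHLO commits are the values to copy into setup_deps.sh.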
Project Structure
```
jax-mps/
├── CMakeLists.txt
├── src/
│   ├── jax_plugins/mps/        # Python JAX plugin
│   ├── pjrt_plugin/            # C++ PJRT implementation
│   │   ├── pjrt_api.cc         # PJRT C API entry point
│   │   ├── mps_client.h/mm     # Metal client management
│   │   ├── mps_executable.h/mm # StableHLO compilation & execution
│   │   └── ops/                # Operation implementations
│   └── proto/                  # Protobuf definitions
└── tests/
```
How It Works
PJRT Plugin
PJRT is the OpenXLA runtime interface that JAX uses to talk to hardware backends; a new backend is added by providing a plugin that implements the PJRT C API. The plugin implements:
- PJRT_Client_Create: initialize the Metal device
- PJRT_Client_Compile: parse the StableHLO program and prepare the MPSGraph
- PJRT_Client_BufferFromHostBuffer: transfer data to the GPU
- PJRT_LoadedExecutable_Execute: run the computation on the GPU
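From the Python side, these entry points are reached through ordinary JAX calls. The correspondence noted in the comments below is approximate and inferred from how PJRT backends are generally driven, not traced through this plugin. With JAX_PLATFORMS=mps set, the calls would go through the plugin; without it, the same code runs on the default backend, which keeps the sketch portable.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Backend initialization (roughly PJRT_Client_Create) happens on first use.
device = jax.devices()[0]

# Host -> device copy (roughly PJRT_Client_BufferFromHostBuffer).
host = np.arange(6, dtype=np.float32).reshape(2, 3)
buf = jax.device_put(host, device)

# Lowering to StableHLO and compiling (roughly PJRT_Client_Compile).
compiled = jax.jit(lambda a: jnp.tanh(a) + a).lower(buf).compile()

# Running the compiled program (roughly PJRT_LoadedExecutable_Execute);
# np.asarray then copies the result back to the host.
out = compiled(buf)
result = np.asarray(out)
print(device.platform, result)
```

Splitting lower() and compile() out explicitly is only for illustration; calling the jitted function directly performs the same steps implicitly and caches the compiled executable.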
MPSGraph Execution
Each StableHLO operation is mapped to an MPSGraph equivalent, for example:
- add → additionWithPrimaryTensor:secondaryTensor:
- dot → matrixMultiplicationWithPrimaryTensor:secondaryTensor:
- tanh → tanhWithTensor:
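The mapping can be pictured as a lookup table from StableHLO op names to MPSGraph selectors. The sketch below is a hypothetical, table-driven illustration built only from the three examples listed; the real implementations live in C++/Objective-C++ under src/pjrt_plugin/ops/, and the names STABLEHLO_TO_MPSGRAPH and plan are made up.

```python
import re

# Hypothetical table using only the three mappings listed in this README.
STABLEHLO_TO_MPSGRAPH = {
    "stablehlo.add": "additionWithPrimaryTensor:secondaryTensor:",
    "stablehlo.dot": "matrixMultiplicationWithPrimaryTensor:secondaryTensor:",
    "stablehlo.tanh": "tanhWithTensor:",
}

OP_PATTERN = re.compile(r"\bstablehlo\.[a-z_]+")


def plan(program: str) -> tuple[list[tuple[str, str]], list[str]]:
    """Split the ops found in `program` into (mapped, unsupported)."""
    mapped, unsupported = [], []
    for op in OP_PATTERN.findall(program):
        selector = STABLEHLO_TO_MPSGRAPH.get(op)
        if selector is None:
            unsupported.append(op)
        else:
            mapped.append((op, selector))
    return mapped, unsupported


program = """
func.func @main(%a: tensor<2x2xf32>, %b: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %0 = stablehlo.dot %a, %b : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  %1 = stablehlo.add %0, %b : tensor<2x2xf32>
  %2 = stablehlo.tanh %1 : tensor<2x2xf32>
  return %2 : tensor<2x2xf32>
}
"""
mapped, unsupported = plan(program)
for op, selector in mapped:
    print(op, "->", selector)
```

Collecting unsupported ops up front, rather than failing partway through graph construction, makes it easy to report every missing operation for a program at once; whether the plugin does this is not stated here.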
File details
Details for the file jax_mps-0.9.3-cp313-cp313-macosx_14_0_arm64.whl.
File metadata
- Download URL: jax_mps-0.9.3-cp313-cp313-macosx_14_0_arm64.whl
- Upload date:
- Size: 7.0 MB
- Tags: CPython 3.13, macOS 14.0+ ARM64
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5976767276eb3ac30e6423e0f1561f61ca47ce2f3c9ffce8a552f60ff772400c |
| MD5 | ec23fa2d2812188e7026c624594968a4 |
| BLAKE2b-256 | fdcbdf84ed91e4a210ded15a7bceb3206579e5569c3bb54396b516c32be33504 |
Provenance
The following attestation bundles were made for jax_mps-0.9.3-cp313-cp313-macosx_14_0_arm64.whl:
- Publisher: build.yml on tillahoffmann/jax-mps
- Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: jax_mps-0.9.3-cp313-cp313-macosx_14_0_arm64.whl
- Subject digest: 5976767276eb3ac30e6423e0f1561f61ca47ce2f3c9ffce8a552f60ff772400c
- Sigstore transparency entry: 869240832
- Sigstore integration time:
- Permalink: tillahoffmann/jax-mps@bc517a2882a3b9e7385271ad9c72ae74acb87864
- Branch / Tag: refs/tags/v0.9.3
- Owner: https://github.com/tillahoffmann
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: build.yml@bc517a2882a3b9e7385271ad9c72ae74acb87864
- Trigger Event: push
Statement type: