JAX backend for Apple Metal Performance Shaders (MPS)

A JAX backend for Apple Silicon using MLX, enabling GPU-accelerated JAX computations on Mac.

Example

jax-mps achieves a ~3.7x speed-up over the CPU backend when training a simple ResNet18 model on CIFAR-10 using an M4 MacBook Air.

$ JAX_PLATFORMS=cpu uv run examples/resnet/main.py --steps=30
loss = 0.029: 100%|██████████| 30/30 [01:41<00:00,  3.37s/it]
Final training loss: 0.029
Time per step (second half): 3.437

$ JAX_PLATFORMS=mps uv run examples/resnet/main.py --steps=30
WARNING:...:jax._src.xla_bridge:905: Platform 'mps' is experimental and not all JAX functionality may be correctly supported!
loss = 0.029: 100%|██████████| 30/30 [00:27<00:00,  1.07it/s]
Final training loss: 0.029
Time per step (second half): 0.928
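The ~3.7x figure quoted above follows directly from the two per-step times:

```python
# Per-step wall-clock times from the runs above (seconds).
cpu_step = 3.437
mps_step = 0.928

speedup = cpu_step / mps_step
print(f"{speedup:.1f}x")  # → 3.7x
```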

Installation

jax-mps requires macOS on Apple Silicon and Python 3.13. Install it with pip:

pip install jax-mps

The plugin registers itself with JAX automatically and is enabled by default. Set JAX_PLATFORMS=mps to select the MPS backend explicitly.

jax-mps is built against the StableHLO bytecode format matching jaxlib 0.9.x. Using a different jaxlib version will likely cause deserialization failures at JIT compile time. See Version Pinning for details.

Architecture

This project implements a PJRT plugin that uses MLX to execute JAX programs on Apple Silicon GPUs. The evaluation proceeds in several stages:

  1. The JAX program is lowered to StableHLO, a set of high-level operations for machine learning applications.
  2. The plugin parses the StableHLO representation and maps operations to MLX equivalents. Compiled programs are cached to avoid re-parsing on repeated invocations.
  3. The MLX operations are executed on the GPU and results are returned to the caller.

Building

  1. Install the build tools, then build and install LLVM/MLIR and StableHLO. This is a one-time setup and takes about 30 minutes. See the setup_deps.sh script for further options, such as forced re-installation and a custom installation location. The script pins LLVM and StableHLO to specific commits matching jaxlib 0.9.0 for bytecode compatibility (see Version Pinning for details).
$ brew install cmake ninja
$ ./scripts/setup_deps.sh
  2. Build the plugin and install it as an editable Python package. This step should be fast, and must be repeated after any change to the C++ files.
$ uv pip install -e .

Version Pinning

setup_deps.sh pins LLVM and StableHLO to specific commits matching jaxlib 0.9.0 for bytecode compatibility. To update these pins for a different jaxlib release, trace the dependency chain:

# 1. Find XLA commit used by jaxlib
curl -s https://raw.githubusercontent.com/jax-ml/jax/jax-v0.9.0/third_party/xla/revision.bzl
# → XLA_COMMIT = "bb760b04..."

# 2. Find LLVM commit used by that XLA version
curl -s https://raw.githubusercontent.com/openxla/xla/<XLA_COMMIT>/third_party/llvm/workspace.bzl
# → LLVM_COMMIT = "f6d0a512..."

# 3. Find StableHLO commit used by that XLA version
curl -s https://raw.githubusercontent.com/openxla/xla/<XLA_COMMIT>/third_party/stablehlo/workspace.bzl
# → STABLEHLO_COMMIT = "127d2f23..."

Then update the STABLEHLO_COMMIT and LLVM_COMMIT_OVERRIDE variables in setup_deps.sh.
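The `.bzl` files fetched above all define the commit as a plain `NAME = "<sha>"` assignment, so extracting it is easy to script (a sketch; `extract_commit` and its regex are mine, not part of the repo):

```python
import re

def extract_commit(bzl_text, name):
    """Pull a NAME = "<sha>" assignment out of a Bazel workspace/revision file."""
    match = re.search(rf'{name}\s*=\s*"([0-9a-f]+)"', bzl_text)
    if match is None:
        raise ValueError(f"{name} not found")
    return match.group(1)

# Shape of jax's third_party/xla/revision.bzl:
sample = 'XLA_COMMIT = "bb760b04"\nXLA_SHA256 = "..."'
print(extract_commit(sample, "XLA_COMMIT"))  # → bb760b04
```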

Project Structure

jax-mps/
├── CMakeLists.txt
├── src/
│   ├── jax_plugins/mps/         # Python JAX plugin
│   ├── pjrt_plugin/             # C++ PJRT implementation
│   │   ├── pjrt_api.cc          # PJRT C API entry point
│   │   ├── mps_client.h/mm      # Metal client management
│   │   ├── mlx_executable.h/mm  # StableHLO compilation & MLX execution
│   │   └── ops/                 # Operation registry
│   └── proto/                   # Protobuf definitions
└── tests/

How It Works

PJRT Plugin

PJRT (Portable JAX Runtime) is JAX's abstraction for hardware backends. The plugin implements:

  • PJRT_Client_Create - Initialize Metal device
  • PJRT_Client_Compile - Parse StableHLO and build MLX operation graph
  • PJRT_Client_BufferFromHostBuffer - Transfer data to GPU
  • PJRT_LoadedExecutable_Execute - Run computation on GPU
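The call order of those four entry points can be illustrated with a pure-Python mock (a sketch of the lifecycle only; the real plugin implements these as PJRT C API callbacks in `pjrt_api.cc`, and the method bodies here are stand-ins):

```python
class MockClient:
    """Mimics the sequence a PJRT client goes through for one computation."""

    def __init__(self):                      # PJRT_Client_Create
        self.device = "metal:0"              # real plugin: initialize Metal

    def compile(self, stablehlo_text):       # PJRT_Client_Compile
        # Real plugin: parse StableHLO and build an MLX op graph.
        return lambda *bufs: sum(bufs)

    def buffer_from_host(self, value):       # PJRT_Client_BufferFromHostBuffer
        return value                         # real plugin: copy to GPU memory

    def execute(self, executable, *bufs):    # PJRT_LoadedExecutable_Execute
        return executable(*bufs)             # real plugin: run on GPU

client = MockClient()
exe = client.compile("stablehlo.add")
a = client.buffer_from_host(1.5)
b = client.buffer_from_host(2.5)
print(client.execute(exe, a, b))  # → 4.0
```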

MLX Execution

StableHLO operations are mapped to MLX equivalents, e.g.:

  • stablehlo.add → mlx::core::add()
  • stablehlo.dot_general → mlx::core::matmul()
  • stablehlo.convolution → mlx::core::conv_general()
  • stablehlo.reduce → mlx::core::sum/max/min/prod()
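For `stablehlo.dot_general`, the common case (a single contracting dimension, last of the lhs against first of the rhs, no batch dimensions) reduces to an ordinary matrix multiply, which is why it can map onto `mlx::core::matmul()`. A minimal sketch of that reduction, with nested lists standing in for MLX arrays:

```python
def dot_general_simple(lhs, rhs):
    """stablehlo.dot_general in the common case: contract the last
    dimension of lhs with the first dimension of rhs, i.e. a plain matmul."""
    rows, inner, cols = len(lhs), len(rhs), len(rhs[0])
    assert len(lhs[0]) == inner, "contracting dimensions must match"
    return [[sum(lhs[i][k] * rhs[k][j] for k in range(inner))
             for j in range(cols)] for i in range(rows)]

print(dot_general_simple([[1, 2]], [[3], [4]]))  # 1*3 + 2*4 → [[11]]
```

The general op also supports batch dimensions and arbitrary contracting-dimension numbers, which the plugin must translate into the equivalent MLX transpose/reshape/matmul sequence.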



