JAX backend for Apple M-series chips

jax-metallib

A PJRT plugin that enables JAX to run on Apple Metal (MPS) GPUs on Apple Silicon. It compiles StableHLO IR to Metal shaders via MPSGraph, giving JAX programs GPU acceleration on M-series Macs.

Status: Alpha (v0.9.1) -- API and op coverage are evolving.

Requirements

  • macOS 13+ on Apple Silicon (M1 / M2 / M3 / M4)
  • Python 3.11+
  • jax and jaxlib 0.9.x

Install

pip install jax-metallib

Verify:

JAX_PLATFORMS=mps python -c "import jax; print(jax.devices())"
# [MpsDevice(id=0)]

Build from source

brew install cmake ninja
git clone https://github.com/erfanzar/jax-metallib.git
cd jax-metallib
uv sync --all-groups
uv pip install -e .       # auto-bootstraps native deps on first build (~30 min)

To skip the automatic dependency bootstrap (if you manage deps yourself):

CMAKE_ARGS="-DJAX_SILICON_AUTO_SETUP_DEPS=OFF" uv pip install -e .

Native dependencies

scripts/setup_deps.sh fetches and builds:

Dependency     Version / Commit
LLVM + MLIR    XLA pin bb760b0
StableHLO      127d2f2
Abseil         20250127.0
Protobuf       29.3

These are installed to ~/.local/jax-silicon-deps by default.

Usage

The plugin registers as the mps platform in JAX:

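import jax
import jax.numpy as jnp

# Select the backend before creating any arrays: either export
# JAX_PLATFORMS=mps, or call jax.config.update("jax_platforms", "mps").
x = jnp.ones((1024, 1024))
y = x @ x  # runs on the Metal GPU
print(y.devices())  # the set of devices holding the result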
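To place work on a specific device rather than relying on the default backend, inputs can be committed with `jax.device_put`. The sketch below assumes the plugin reports its devices under the `mps` platform name (as the verification command suggests) and falls back to the CPU so it also runs on machines without the plugin; `pick_device` is a hypothetical helper, not part of jax-metallib.

```python
import jax
import jax.numpy as jnp


def pick_device(platform: str = "mps") -> jax.Device:
    """Return the first device for `platform`, falling back to the CPU.

    Hypothetical helper: assumes the plugin registers the "mps" platform.
    jax.devices() raises RuntimeError for a backend that is not available.
    """
    try:
        return jax.devices(platform)[0]
    except RuntimeError:
        return jax.devices("cpu")[0]


device = pick_device()
x = jax.device_put(jnp.ones((4, 4)), device)  # commit the input to `device`
y = jax.jit(lambda a: a @ a)(x)               # jit runs where committed inputs live
print(device.platform, y[0, 0])
```

Committed inputs keep the computation on their device; uncommitted arrays follow the default backend chosen through `JAX_PLATFORMS`.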

Environment variables

Variable                    Description
JAX_PLATFORMS=mps           Select the MPS backend
JAX_SILICON_LIBRARY_PATH    Override the path to libpjrt_plugin_silicon.dylib
JAX_MPS_LIBRARY_PATH        Legacy alias for JAX_SILICON_LIBRARY_PATH
MPS_LOG_LEVEL=0..3          Logging verbosity (0=error, 1=warn, 2=info, 3=debug)
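The variables in the table imply a lookup order for the library path and a numeric-to-name mapping for logging. A minimal sketch of how a launcher script might resolve them; `resolve_library_path` and `resolve_log_level` are hypothetical helpers, and both the precedence (current name over legacy alias) and the default log level are assumptions, not documented plugin behavior.

```python
import os
from typing import Mapping, Optional

# Level names taken from the MPS_LOG_LEVEL row of the table.
LOG_LEVELS = {0: "error", 1: "warn", 2: "info", 3: "debug"}


def resolve_library_path(env: Mapping[str, str] = os.environ) -> Optional[str]:
    """Return the plugin library override, or None to use the bundled dylib.

    Assumption: JAX_SILICON_LIBRARY_PATH wins over the legacy alias
    JAX_MPS_LIBRARY_PATH when both are set.
    """
    return env.get("JAX_SILICON_LIBRARY_PATH") or env.get("JAX_MPS_LIBRARY_PATH")


def resolve_log_level(env: Mapping[str, str] = os.environ, default: int = 0) -> str:
    """Map MPS_LOG_LEVEL (0..3) to a level name.

    Non-numeric or out-of-range values fall back to `default` (an assumed
    default; the README does not state one).
    """
    try:
        level = int(env.get("MPS_LOG_LEVEL", default))
    except ValueError:
        level = default
    return LOG_LEVELS.get(level, LOG_LEVELS[default])
```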

Supported operations

100+ StableHLO operations are implemented across these categories:

Category         Examples
Unary            tanh, exp, log, sin, cos, sqrt, erf, abs, sign
Binary           add, subtract, multiply, divide, dot, comparisons
Reductions       reduce_sum, reduce_max, reduce_min, argmax, argmin
Shape            reshape, transpose, slice, pad, concatenate, gather, scatter
Convolution      conv_general_dilated with arbitrary padding/dilation
Linear algebra   matmul, cholesky, qr, svd, triangular_solve
FFT              fft, rfft, ifft, irfft
Random           Threefry / Philox RNG
Sort             sort, top_k
Control flow     cond (if/else), while_loop, scan
Bitwise          and, or, xor, shift_left, shift_right
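The control-flow row corresponds to `jax.lax` primitives. The sketch below exercises `scan`, `cond`, and `while_loop` under `jax.jit`; it uses only standard JAX APIs, so it runs on any backend, and with the plugin installed and `JAX_PLATFORMS=mps` it should be compiled by the MPS backend. The two functions are illustrative, not taken from the project's tests.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def clipped_cumsum(xs: jax.Array, limit: float) -> jax.Array:
    """Running sum via lax.scan; each step clamps the total at `limit` with lax.cond."""

    def step(carry, x):
        total = carry + x
        # lax.cond lowers to a StableHLO if/else-style conditional.
        total = lax.cond(
            total > limit,
            lambda t: jnp.asarray(limit, t.dtype),  # clamp branch
            lambda t: t,                            # pass-through branch
            total,
        )
        return total, total  # (new carry, per-step output)

    _, out = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return out


@jax.jit
def count_halvings(n: jax.Array) -> jax.Array:
    """lax.while_loop: number of integer halvings until n drops to 1."""

    def body(state):
        value, steps = state
        return value // 2, steps + 1

    _, steps = lax.while_loop(lambda s: s[0] > 1, body, (n, jnp.int32(0)))
    return steps
```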

When a program uses an unsupported op, the plugin prints a diagnostic with a link for filing a feature request.
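A similar check can be run ahead of time from Python: `jax.jit(fn).lower(...).as_text()` returns the StableHLO module that JAX hands to a PJRT plugin for compilation, so scanning it for op names shows which ops a program needs. This is an illustrative sketch: `SUPPORTED` is a made-up subset rather than the plugin's real registry, and the regex is a rough textual scan, not an MLIR parse.

```python
import re

import jax
import jax.numpy as jnp

# Hypothetical subset of supported op names; the real list lives in the
# plugin's op registry under src/pjrt_plugin/ops/.
SUPPORTED = {"tanh", "constant", "broadcast_in_dim", "multiply"}

# Matches op names such as "stablehlo.tanh" or "stablehlo.broadcast_in_dim".
_OP_RE = re.compile(r"\bstablehlo\.([a-z0-9_]+)")


def stablehlo_ops(fn, *args) -> set[str]:
    """Lower `fn` with jax.jit and collect the StableHLO op names it uses."""
    ir_text = jax.jit(fn).lower(*args).as_text()  # StableHLO MLIR text
    return set(_OP_RE.findall(ir_text))


def find_unsupported(fn, *args, supported=SUPPORTED) -> set[str]:
    """Ops the program needs that are missing from `supported`."""
    return stablehlo_ops(fn, *args) - supported


print(find_unsupported(lambda x: jnp.tanh(x) + 1.0, jnp.ones(3)))
```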

Testing

uv run pytest                   # compare CPU vs MPS (default)
JAX_TEST_MODE=mps uv run pytest # MPS only
JAX_TEST_MODE=cpu uv run pytest # CPU only

The test suite covers value correctness, gradient accuracy, and includes integration tests with Flax and NumPyro.
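The CPU-vs-MPS comparison can be sketched as a pytest-style test that evaluates a function and its gradient on both devices and compares them with `numpy.testing.assert_allclose`. This is not the project's `test_ops.py`: the helper names, the `mps` platform lookup, and the tolerances are assumptions (float32 GPU results may need looser bounds), and the target falls back to the CPU when the plugin is absent.

```python
import jax
import jax.numpy as jnp
import numpy as np


def target_device(platform: str = "mps") -> jax.Device:
    """Device under test; falls back to CPU when the plugin is not installed."""
    try:
        return jax.devices(platform)[0]
    except RuntimeError:
        return jax.devices("cpu")[0]


def loss(w, x):
    return jnp.sum(jnp.tanh(x @ w) ** 2)


def check_against_cpu(fn, *args, rtol=1e-5, atol=1e-5):
    """Evaluate fn and its gradient (w.r.t. the first argument) on CPU and on the target."""
    cpu = jax.devices("cpu")[0]
    results = []
    for device in (cpu, target_device()):
        placed = [jax.device_put(a, device) for a in args]
        value, grad = jax.jit(jax.value_and_grad(fn))(*placed)
        results.append((np.asarray(value), np.asarray(grad)))
    (v_cpu, g_cpu), (v_dev, g_dev) = results
    np.testing.assert_allclose(v_dev, v_cpu, rtol=rtol, atol=atol)  # value correctness
    np.testing.assert_allclose(g_dev, g_cpu, rtol=rtol, atol=atol)  # gradient accuracy


def test_loss_matches_cpu():
    key_w, key_x = jax.random.split(jax.random.key(0))
    w = jax.random.normal(key_w, (8, 4))
    x = jax.random.normal(key_x, (16, 8))
    check_against_cpu(loss, w, x)
```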

How it works

JAX Python code
      |
StableHLO IR (MLIR)
      |
stablehlo_parser.mm   -- parses IR, looks up ops in the registry
      |
MPSGraph operations   -- builds a Metal compute graph
      |
Metal command buffer  -- compiled & dispatched to GPU
      |
Device memory result

The PJRT C API (pjrt_api.cc) exposes client, device, buffer, and executable abstractions that JAX expects. Each StableHLO op is registered in src/pjrt_plugin/ops/ and mapped to the corresponding MPSGraph method.
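The lookup-and-dispatch pattern described here can be sketched in a few lines of Python. This is a hypothetical illustration of a name-keyed op registry, not the plugin's code: the real handlers are C++/Obj-C++ functions under `src/pjrt_plugin/ops/` that emit MPSGraph nodes, NumPy stands in for MPSGraph below, and `register_op` and `lower_op` are invented names.

```python
from typing import Callable, Dict

import numpy as np

# Hypothetical registry: StableHLO op name -> handler producing the result.
OP_REGISTRY: Dict[str, Callable[..., np.ndarray]] = {}


def register_op(name: str):
    """Decorator that records a handler for a StableHLO op name."""

    def wrap(fn):
        OP_REGISTRY[name] = fn
        return fn

    return wrap


@register_op("stablehlo.add")
def _add(lhs, rhs):
    return np.add(lhs, rhs)


@register_op("stablehlo.tanh")
def _tanh(x):
    return np.tanh(x)


def lower_op(name: str, *operands):
    """Look up `name` in the registry and fail loudly for unsupported ops."""
    handler = OP_REGISTRY.get(name)
    if handler is None:
        raise NotImplementedError(f"unsupported StableHLO op: {name}")
    return handler(*operands)
```

Keying handlers by op name keeps adding an op to a single registration, which matches the contributing workflow of dropping a new file into the ops directory.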

Repository layout

src/
  jax_plugins/silicon/    Python entrypoint (plugin registration)
  pjrt_plugin/            C++/Obj-C++ PJRT backend
    ops/                  Op implementations (~100+ ops)
    stablehlo_parser.mm   StableHLO IR -> MPSGraph compiler
    mps_client.mm         Metal device & command queue management
    mps_executable.mm     Executable compilation & dispatch
tests/
  test_ops.py             Parametrized op tests
  configs/                Per-category test configurations
scripts/
  setup_deps.sh           One-time native dependency bootstrap
  release.sh              Release automation

Contributing

See CONTRIBUTING.md. The short version:

  1. brew install cmake ninja && ./scripts/setup_deps.sh && uv sync --all-groups
  2. uv pip install -e . && pre-commit install
  3. Add your op in src/pjrt_plugin/ops/, add a test in tests/configs/, rebuild, and run uv run pytest.

License

Apache-2.0. See LICENSE.

This project is a derivative of tillahoffmann/jax-mps.


Download files

Source Distribution

jax_metallib-0.9.4.tar.gz (257.3 kB)

Built Distribution

jax_metallib-0.9.4-cp313-cp313-macosx_26_0_arm64.whl (6.8 MB; CPython 3.13, macOS 26.0+ ARM64)

File details

Details for the file jax_metallib-0.9.4.tar.gz.

File metadata

  • Download URL: jax_metallib-0.9.4.tar.gz
  • Upload date:
  • Size: 257.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.21

File hashes

Hashes for jax_metallib-0.9.4.tar.gz
Algorithm Hash digest
SHA256 1a851486867c291bfa02baa42b02f2c44286a8e475f557636b4d5b61bc1fc4b5
MD5 aa8dd868fc502e424e39a7578988e6b4
BLAKE2b-256 4d1b0a91b264690cb7b2caa4673099ae19a4e8e152385557fc7a5e962abde673
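A downloaded file can be checked against the published SHA256 digest with the standard library. A minimal sketch; `sha256_of` and `verify` are illustrative helper names.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks so large wheels need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected_sha256: str) -> bool:
    """Compare against a published hex digest (case-insensitive)."""
    return sha256_of(path) == expected_sha256.lower()
```

For example, `verify(Path("jax_metallib-0.9.4.tar.gz"), "1a851486867c291bfa02baa42b02f2c44286a8e475f557636b4d5b61bc1fc4b5")` should return True for an intact download. pip can enforce the same check at install time through its hash-checking mode (`--hash=sha256:<digest>` entries in a requirements file).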

File details

Details for the file jax_metallib-0.9.4-cp313-cp313-macosx_26_0_arm64.whl.

File metadata

  • Download URL: jax_metallib-0.9.4-cp313-cp313-macosx_26_0_arm64.whl
  • Upload date:
  • Size: 6.8 MB
  • Tags: CPython 3.13, macOS 26.0+ ARM64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.21

File hashes

Hashes for jax_metallib-0.9.4-cp313-cp313-macosx_26_0_arm64.whl
Algorithm Hash digest
SHA256 498c81fd07f0718cb0dcd7a3224af3ade973cf00464c5109f00bd17481f22f03
MD5 6ce9610cb195ba43e410e1098276e493
BLAKE2b-256 75a31e00232b1bc3e5c777af1389ca5e5b8799f5de309b6b5aa0518be6c17fb3
