
JAX backend for Apple M-series chips


jax-metallib

A PJRT plugin that enables JAX to run on Apple Metal (MPS) GPUs on Apple Silicon. It compiles StableHLO IR to Metal shaders via MPSGraph, giving JAX programs GPU acceleration on M-series Macs.

Status: Alpha (v0.9.0). The API and op coverage are still evolving.

Requirements

  • macOS 13+ on Apple Silicon (M1 / M2 / M3 / M4)
  • Python 3.11+
  • jax and jaxlib 0.9.x

Install

pip install jax-metallib

Verify:

JAX_PLATFORMS=mps python -c "import jax; print(jax.devices())"
# [MpsDevice(id=0)]
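The same check can be run from Python, for example in a script that should fall back gracefully when the plugin is missing. This is a minimal sketch that assumes only that the plugin exposes its devices under the `mps` platform name, as described here. `mps_available` is a hypothetical helper and is not part of the package:

```python
import importlib.util


def mps_available() -> bool:
    """Return True if JAX is installed and exposes at least one `mps` device."""
    if importlib.util.find_spec("jax") is None:
        return False
    import jax

    try:
        return len(jax.devices("mps")) > 0
    except (RuntimeError, ValueError):
        # JAX raises RuntimeError when a backend is unknown or fails to initialize.
        return False


if __name__ == "__main__":
    print("MPS available:", mps_available())
```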

Build from source

brew install cmake ninja
git clone https://github.com/erfanzar/jax-metallib.git
cd jax-metallib
uv sync --all-groups
uv pip install -e .       # auto-bootstraps native deps on first build (~30 min)

To skip the automatic dependency bootstrap (if you manage deps yourself):

CMAKE_ARGS="-DJAX_SILICON_AUTO_SETUP_DEPS=OFF" uv pip install -e .

Native dependencies

scripts/setup_deps.sh fetches and builds:

Dependency     Version / Commit
LLVM + MLIR    XLA pin bb760b0
StableHLO      127d2f2
Abseil         20250127.0
Protobuf       29.3

These are installed to ~/.local/jax-silicon-deps by default.

Usage

The plugin registers as the mps platform in JAX:

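import jax
import jax.numpy as jnp

# Select the MPS backend before running any computation: either export
# JAX_PLATFORMS=mps in the shell or set the equivalent config option.
jax.config.update("jax_platforms", "mps")

x = jnp.ones((1024, 1024))
y = x @ x  # runs on the Metal GPU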
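Transformations such as `jax.jit` and `jax.grad` go through the same StableHLO path as the matmul above. The sketch below does not force a platform, so it runs on whatever backend JAX selects. Launch it with `JAX_PLATFORMS=mps` to target the GPU. The `loss` function and the shapes are illustrative only, and whether every op in a real model lowers on MPS depends on the plugin's op coverage.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Sum of squared outputs of a bias-free linear layer.
    return jnp.sum((x @ w) ** 2)


# jax.grad builds the gradient function; jax.jit compiles it once per input shape/dtype.
grad_fn = jax.jit(jax.grad(loss))

w = jnp.ones((4, 2))
x = jnp.ones((3, 4))
g = grad_fn(w, x)  # analytically: 2 * x.T @ (x @ w)
print(g.shape, jax.devices()[0].platform)
```

Under `JAX_PLATFORMS=mps`, the printed platform should be the plugin's rather than `cpu`.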

Environment variables

Variable                   Description
JAX_PLATFORMS=mps          Select the MPS backend
JAX_SILICON_LIBRARY_PATH   Override path to libpjrt_plugin_silicon.dylib
JAX_MPS_LIBRARY_PATH       Legacy alias for the above
MPS_LOG_LEVEL=0..3         Logging verbosity (0=error, 1=warn, 2=info, 3=debug)

Supported operations

100+ StableHLO operations are implemented across these categories:

Category         Examples
Unary            tanh, exp, log, sin, cos, sqrt, erf, abs, sign
Binary           add, subtract, multiply, divide, dot, comparisons
Reductions       reduce_sum, reduce_max, reduce_min, argmax, argmin
Shape            reshape, transpose, slice, pad, concatenate, gather, scatter
Convolution      conv_general_dilated with arbitrary padding/dilation
Linear algebra   matmul, cholesky, qr, svd, triangular_solve
FFT              fft, rfft, ifft, irfft
Random           Threefry / Philox RNG
Sort             sort, top_k
Control flow     cond (if/else), while_loop, scan
Bitwise          and, or, xor, shift_left, shift_right

Encountering an unsupported op prints a diagnostic with a link to file a feature request.
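To see ahead of time which ops a program will ask the plugin to compile, you can lower it and count the op names in the resulting module text. This is a rough sketch that uses a regular expression rather than a real MLIR parser. It also counts `chlo.*` ops, which JAX can emit alongside StableHLO, and it skips attribute syntax such as `#stablehlo.scatter<...>`. JAX-level names such as `reduce_sum` show up as StableHLO ops such as `stablehlo.reduce`.

```python
import re
from collections import Counter

# Op names like `stablehlo.add`, `"stablehlo.scatter"` or `chlo.erf`; the
# lookbehind rejects attributes (`#stablehlo.scatter<...>`) and longer identifiers.
_OP_RE = re.compile(r"(?<![#\w])((?:stablehlo|chlo)\.[a-z][a-z0-9_]*)")


def count_ops(module_text: str) -> Counter:
    """Count occurrences of each StableHLO/CHLO op in MLIR module text."""
    return Counter(_OP_RE.findall(module_text))
```

To feed it a real program, pass the text returned by `jax.jit(f).lower(*args).as_text()`, which emits StableHLO by default in recent JAX releases.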

Testing

uv run pytest                   # compare CPU vs MPS (default)
JAX_TEST_MODE=mps uv run pytest # MPS only
JAX_TEST_MODE=cpu uv run pytest # CPU only

The test suite checks value correctness and gradient accuracy, and includes integration tests with Flax and NumPyro.

How it works

JAX Python code
      |
StableHLO IR (MLIR)
      |
stablehlo_parser.mm   -- parses IR, looks up ops in the registry
      |
MPSGraph operations   -- builds a Metal compute graph
      |
Metal command buffer  -- compiled & dispatched to GPU
      |
Device memory result

The PJRT C API layer (pjrt_api.cc) exposes the client, device, buffer, and executable abstractions that JAX expects. Each StableHLO op is registered in src/pjrt_plugin/ops/ and mapped to the corresponding MPSGraph method.
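The registry step in the diagram is essentially a lookup table from op name to builder function. The plain-Python sketch below shows that pattern. The real registry lives in C++/Obj-C++ under src/pjrt_plugin/ops/ and emits MPSGraph nodes. These hypothetical builders just compute on Python numbers so that both the lookup and the unsupported-op path can be exercised.

```python
from collections.abc import Callable

# Hypothetical registry: StableHLO op name -> builder function.
_REGISTRY: dict[str, Callable[..., object]] = {}


class UnsupportedOpError(NotImplementedError):
    """Raised when a module uses an op with no registered builder."""


def register_op(name: str):
    """Decorator that records the decorated function as the builder for `name`."""
    def decorator(fn):
        if name in _REGISTRY:
            raise ValueError(f"op {name!r} is already registered")
        _REGISTRY[name] = fn
        return fn
    return decorator


def lower_op(name: str, *operands):
    """Look up `name` and apply its builder, mirroring the parser's registry lookup."""
    builder = _REGISTRY.get(name)
    if builder is None:
        raise UnsupportedOpError(f"unsupported StableHLO op: {name}")
    return builder(*operands)


@register_op("stablehlo.add")
def _add(a, b):
    # In the plugin this step would add a node to an MPSGraph instead.
    return a + b


@register_op("stablehlo.multiply")
def _multiply(a, b):
    return a * b


if __name__ == "__main__":
    print(lower_op("stablehlo.multiply", lower_op("stablehlo.add", 2, 3), 4))  # 20
```

Adding an op then amounts to registering one more builder, which matches the contribution flow described below.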

Repository layout

src/
  jax_plugins/silicon/    Python entrypoint (plugin registration)
  pjrt_plugin/            C++/Obj-C++ PJRT backend
    ops/                  Op implementations (100+ ops)
    stablehlo_parser.mm   StableHLO IR -> MPSGraph compiler
    mps_client.mm         Metal device & command queue management
    mps_executable.mm     Executable compilation & dispatch
tests/
  test_ops.py             Parametrized op tests
  configs/                Per-category test configurations
scripts/
  setup_deps.sh           One-time native dependency bootstrap
  release.sh              Release automation

Contributing

See CONTRIBUTING.md. The short version:

  1. brew install cmake ninja && ./scripts/setup_deps.sh && uv sync --all-groups
  2. uv pip install -e . && pre-commit install
  3. Add your op in src/pjrt_plugin/ops/, add a test in tests/configs/, rebuild, and run uv run pytest.

License

Apache-2.0. See LICENSE.

This project is a derivative of tillahoffmann/jax-mps.



Download files


Source Distribution

jax_metallib-0.9.1.tar.gz (161.4 kB)

Uploaded Source

Built Distribution


jax_metallib-0.9.1-cp313-cp313-macosx_26_0_arm64.whl (6.7 MB)

Uploaded CPython 3.13, macOS 26.0+ ARM64

File details

Details for the file jax_metallib-0.9.1.tar.gz.

File metadata

  • Download URL: jax_metallib-0.9.1.tar.gz
  • Upload date:
  • Size: 161.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.21

File hashes

Hashes for jax_metallib-0.9.1.tar.gz
Algorithm     Hash digest
SHA256        09ce4d944d0ffa212c12ca8742665d776de333e4f549d8655acda9585e55aed6
MD5           3604f715c706e422b40014a9015f879d
BLAKE2b-256   5a79f229d52bf17bdc0108fddd6caf527eed78a242fe5dc14e23b5a23361b01e


File details

Details for the file jax_metallib-0.9.1-cp313-cp313-macosx_26_0_arm64.whl.

File metadata

  • Download URL: jax_metallib-0.9.1-cp313-cp313-macosx_26_0_arm64.whl
  • Upload date:
  • Size: 6.7 MB
  • Tags: CPython 3.13, macOS 26.0+ ARM64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.21

File hashes

Hashes for jax_metallib-0.9.1-cp313-cp313-macosx_26_0_arm64.whl
Algorithm     Hash digest
SHA256        816bf346b3f1b12e5f6aee6156d80e4da1ae2413203e4a032806b8379ef01d51
MD5           cf9aa45ccb4a409b789b951acb2e7274
BLAKE2b-256   2980be7ed0aebc963f9a51e9e3ef45021780b68a2af7b3aa2e2a79669a409ec7

