
JAX backend for Apple M-series chips

jax-metallib
Run JAX on Apple Metal GPUs



jax-metallib is a PJRT plugin that enables JAX to run on Apple Metal GPUs. It compiles StableHLO IR to Metal compute kernels via MPSGraph, giving JAX programs native GPU acceleration on M-series Macs — no code changes required.

Highlights

  • 120+ StableHLO ops — from basic arithmetic through convolutions, FFTs, linear algebra (Cholesky, QR, SVD), sorting, and control flow
  • Drop-in acceleration — set JAX_PLATFORMS=mps and existing JAX code runs on the Metal GPU
  • Full gradient support: jax.grad, jax.value_and_grad, and higher-order derivatives work out of the box
  • JIT kernel fusion — consecutive elementwise ops are fused into single Metal Shading Language kernels at runtime
  • Native MPS kernels — performance-critical operations (Cholesky decomposition, triangular solve) use MPS native kernels directly, bypassing the graph compiler
  • Executable serialization — compiled programs can be serialized and deserialized for AOT compilation and jax.jit caching
  • Framework integration — tested with Flax (NNX) and NumPyro
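The gradient bullet above refers to standard JAX transformations. The sketch below uses only public JAX APIs (jax.value_and_grad, nested jax.grad, jax.jit), so nothing in it is specific to jax-metallib: with JAX_PLATFORMS=mps it should exercise the plugin, and without it the same code runs on CPU.

```python
import jax
import jax.numpy as jnp


def loss(w):
    # Scalar loss built from an elementwise op and a sum reduction.
    return jnp.sum(jnp.tanh(w) ** 2)


w = jnp.array([0.0, 0.5, -1.0])

# Loss value and gradient in one call.
value, grad = jax.jit(jax.value_and_grad(loss))(w)

# Higher-order derivative: d²/dx² of x * sin(x), which equals 2cos(x) - x sin(x).
f = lambda x: x * jnp.sin(x)
d2f = jax.jit(jax.grad(jax.grad(f)))

print("loss:", value)
print("grad:", grad)
print("f''(1.0):", d2f(1.0))
```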

Quick Start

Install

pip install jax-metallib

Verify

JAX_PLATFORMS=mps python -c "import jax; print(jax.devices())"

Run

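import jax
import jax.numpy as jnp

# Run with JAX_PLATFORMS=mps so these arrays live on the Metal GPU.
x = jnp.ones((1024, 1024))
y = x @ x  # matrix multiply, lowered to stablehlo.dot_general

# JIT-compiled gradient of a scalar function.
f = jax.jit(jax.grad(lambda x: jnp.sum(jnp.tanh(x))))
grads = f(x)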

Requirements

Requirement   Version
macOS         13.0+ (Ventura or later)
Hardware      Apple Silicon (M1 / M2 / M3 / M4 / M5)
Python        3.11, 3.12, or 3.13
JAX           0.9.x
jaxlib        0.9.x

Build from Source

Building from source compiles the native Metal plugin (~6 MB shared library). The first build automatically bootstraps LLVM/MLIR and StableHLO dependencies, which takes about 30 minutes.

brew install cmake ninja

git clone https://github.com/erfanzar/jax-metallib.git
cd jax-metallib
uv sync --all-groups
uv pip install -e .

To skip the automatic dependency bootstrap (if you manage LLVM/MLIR yourself):

CMAKE_ARGS="-DJAX_METALLIB_AUTO_SETUP_DEPS=OFF" uv pip install -e .

Native Dependencies

The bootstrap script (scripts/setup_deps.sh) fetches and builds the following into ~/.local/jax-metallib-deps:

Dependency    Version           Purpose
LLVM + MLIR   XLA pin bb760b0   MLIR infrastructure, StableHLO dialect support
StableHLO     127d2f2           IR parsing, serialization, dialect definitions
Abseil        20250127.0        C++ utilities (strings, status, synchronization)
Protobuf      29.3              Device assignment protocol buffer serialization

Supported Operations

122 operations are registered across StableHLO, CHLO, and MHLO dialects:

Unary operations (42 ops)

abs cbrt ceil cosine count_leading_zeros exponential exponential_minus_one erf floor imag is_finite log log_plus_one logistic negate real round_nearest_even rsqrt sign sine sqrt tan tanh

CHLO: acos acosh asin asinh atanh bessel_i1e conj cosh digamma erf_inv erfc is_inf is_neg_inf is_pos_inf lgamma sinh square

Binary operations (15 ops)

add atan2 clamp compare divide dot dot_general maximum minimum multiply power remainder select subtract

CHLO: next_after

Shape & indexing (18 ops)

bitcast_convert broadcast broadcast_in_dim concatenate convert dynamic_broadcast_in_dim dynamic_reshape dynamic_slice dynamic_update_slice gather get_dimension_size pad reshape reverse scatter set_dimension_size slice transpose

Reductions (7 ops)

reduce (sum, product, max, min, and, or, argmax, argmin) reduce_window select_and_scatter batch_norm_inference batch_norm_training batch_norm_grad return

Convolution

convolution — full conv_general_dilated with arbitrary padding, dilation, strides, feature grouping, and batch grouping

Linear algebra

cholesky (native MPS kernel) triangular_solve (native MPS kernel)

Other categories
Category          Operations
Bitwise           and or xor not shift_left shift_right_logical shift_right_arithmetic popcnt
FFT               fft (FFT, RFFT, IFFT, IRFFT)
Sort              sort top_k
Random            rng rng_bit_generator (Threefry / Philox)
Tensor creation   constant iota
Control flow      while case (if/else) custom_call
Collective        all_reduce all_gather reduce_scatter collective_permute
Higher-order      map complex

When the plugin encounters an unsupported op, it prints a diagnostic that includes a direct link for filing a feature request.
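To check ahead of time which ops a function will need (and compare them against the lists above), you can lower it with jax.jit(...).lower(...) and scan the MLIR text. This is a minimal sketch built on JAX's public Lowered.as_text(); the helper name stablehlo_ops is hypothetical. Lowering is platform-dependent (for example, some linear algebra lowers to a custom_call on CPU), so run it with the Metal backend selected for a representative answer.

```python
import re

import jax
import jax.numpy as jnp


def stablehlo_ops(fn, *example_args):
    """Return the sorted StableHLO/CHLO op names that `fn` lowers to for these inputs."""
    # Lower without compiling or running; as_text() returns the MLIR module.
    text = jax.jit(fn).lower(*example_args).as_text()
    # Ops appear as e.g. "stablehlo.tanh", "stablehlo.reduce", "chlo.erf_inv".
    return sorted(set(re.findall(r"\b(?:stablehlo|chlo)\.([a-z_]+)", text)))


print(stablehlo_ops(lambda x: jnp.sum(jnp.tanh(x)), jnp.ones((4, 4))))
```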

Architecture

JAX Python program
        │
        ▼
  StableHLO IR (MLIR bytecode)          ← jax.jit compiles Python to StableHLO
        │
        ▼
  PJRT C API layer                      ← pjrt_api.cc exposes client/device/buffer/exec
        │
        ▼
  StableHLO Parser                      ← deserializes + inlines MLIR modules
        │
        ▼
  Execution Plan Builder                ← walks ops, groups into MPSGraph segments
        │                                  + native MPS kernel steps
        ├──► MPSGraph segments           ← op handlers build compute graphs
        │         │
        │         ▼
        │    Metal command buffer        ← compiled & dispatched to GPU
        │
        └──► Native MPS kernels          ← direct kernel dispatch (Cholesky, etc.)
                  │
                  ▼
         Device memory (MTLBuffer)       ← results flow back to JAX as DeviceArrays

The plugin implements three execution paths, which are interleaved within a single program:

  1. Graph execution — Consecutive ops are batched into an MPSGraph, compiled to a Metal compute pipeline, and dispatched as a single GPU command. This is the primary path for most operations.

  2. Native execution — Performance-critical operations (e.g., Cholesky decomposition via MPSMatrixDecompositionCholesky) bypass the graph compiler and dispatch MPS native kernels directly on MTLBuffer objects.

  3. JIT fusion — Chains of elementwise operations are detected and fused into custom Metal Shading Language (MSL) kernels at runtime, reducing kernel launch overhead and improving memory bandwidth utilization.
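The three paths can be pictured with a toy planner. This sketch is hypothetical and far simpler than the plugin's Execution Plan Builder: it treats a program as a flat list of op names (real StableHLO is a dataflow graph), uses made-up "native" and "elementwise" op sets, cuts the list into graph segments around native kernel steps, and marks runs of consecutive elementwise ops as fusion candidates. How the plugin actually schedules fused MSL kernels relative to MPSGraph segments is not described here.

```python
from dataclasses import dataclass, field

# Hypothetical op classes for illustration only; not taken from the
# plugin's op registry.
NATIVE_OPS = {"cholesky", "triangular_solve"}
ELEMENTWISE_OPS = {"add", "subtract", "multiply", "negate", "exponential", "tanh", "logistic"}


@dataclass
class Step:
    kind: str  # "graph" (MPSGraph segment) or "native" (direct MPS kernel)
    ops: list
    fusion_runs: list = field(default_factory=list)  # elementwise chains to fuse


def elementwise_runs(ops, min_len=2):
    """Return maximal runs of consecutive elementwise ops of length >= min_len."""
    runs, current = [], []
    for op in ops + [None]:  # sentinel flushes the last run
        if op in ELEMENTWISE_OPS:
            current.append(op)
            continue
        if len(current) >= min_len:
            runs.append(current)
        current = []
    return runs


def build_plan(ops):
    """Cut a flat op list into graph segments separated by native kernel steps."""
    plan, segment = [], []
    for op in ops + [None]:  # sentinel flushes the last segment
        if op is None or op in NATIVE_OPS:
            if segment:
                plan.append(Step("graph", segment, elementwise_runs(segment)))
                segment = []
            if op is not None:
                plan.append(Step("native", [op]))
        else:
            segment.append(op)
    return plan


program = ["dot_general", "add", "tanh", "multiply", "cholesky", "triangular_solve", "reduce"]
plan = build_plan(program)
for step in plan:
    print(step.kind, step.ops, step.fusion_runs)
```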

Configuration

Environment Variables

Variable                    Default   Description
JAX_PLATFORMS               (unset)   Set to mps to select the Metal backend
JAX_METALLIB_LIBRARY_PATH   auto      Override path to libpjrt_plugin_silicon.dylib
MPS_LOG_LEVEL               1         Logging verbosity: 0 error, 1 warn, 2 info, 3 debug
JAX_TEST_MODE               compare   Test mode: compare (CPU vs MPS), mps, or cpu

Library Discovery

The plugin searches for the native library in this order:

  1. JAX_METALLIB_LIBRARY_PATH environment variable
  2. Package directory (editable install)
  3. <package>/lib/ (wheel install)
  4. build/*/lib/ (CMake build directory)
  5. /usr/local/lib/, /opt/homebrew/lib/
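The search order above can be expressed as a short sketch. It is not the code in src/jax_plugins/silicon/; the function names, the build_root parameter (a build directory relative to the current working directory), and the choice to skip a JAX_METALLIB_LIBRARY_PATH that points at a missing file are all assumptions.

```python
import glob
import os
from pathlib import Path

LIB_NAME = "libpjrt_plugin_silicon.dylib"


def candidate_paths(package_dir, build_root="build"):
    """Yield candidate library paths in the documented search order."""
    package_dir = Path(package_dir)
    override = os.environ.get("JAX_METALLIB_LIBRARY_PATH")
    if override:
        yield Path(override)                      # 1. explicit override
    yield package_dir / LIB_NAME                  # 2. editable install
    yield package_dir / "lib" / LIB_NAME          # 3. wheel install
    pattern = os.path.join(build_root, "*", "lib", LIB_NAME)
    for match in sorted(glob.glob(pattern)):      # 4. CMake build directories
        yield Path(match)
    for system_dir in ("/usr/local/lib", "/opt/homebrew/lib"):
        yield Path(system_dir) / LIB_NAME         # 5. system-wide locations


def find_plugin_library(package_dir, build_root="build"):
    """Return the first candidate that exists on disk, or None."""
    return next((p for p in candidate_paths(package_dir, build_root) if p.is_file()), None)
```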

Testing

The test suite validates numerical correctness by comparing CPU and MPS results with configurable tolerances.

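# Full suite in the default compare mode (CPU vs MPS)
uv run pytest

# Run on MPS only
JAX_TEST_MODE=mps uv run pytest

# Select one category by keyword
uv run pytest -k "unary"
uv run pytest -k "linalg"
uv run pytest -k "flax"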
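The compare mode can be approximated in plain JAX: run the same jitted function once on the default backend and once with its inputs committed to the CPU device, then check that the results agree within a tolerance. This is a hypothetical helper, not the project's test harness, and the tolerances are illustrative. It assumes both backends are initialized, for example with JAX_PLATFORMS=mps,cpu (JAX uses the first entry as the default); with JAX_PLATFORMS=mps alone, jax.devices("cpu") may not be available.

```python
import numpy as np

import jax
import jax.numpy as jnp


def compare_with_cpu(fn, *args, rtol=1e-5, atol=1e-6):
    """Run fn on the default backend and on CPU; raise AssertionError on mismatch."""
    out = jax.jit(fn)(*args)
    cpu = jax.devices("cpu")[0]
    # Committing the inputs to the CPU device makes the jitted call run there.
    cpu_args = [jax.device_put(a, cpu) for a in args]
    cpu_out = jax.jit(fn)(*cpu_args)
    np.testing.assert_allclose(np.asarray(out), np.asarray(cpu_out), rtol=rtol, atol=atol)
    return out


x = jnp.linspace(-3.0, 3.0, 64)
result = compare_with_cpu(lambda v: jnp.log1p(jnp.exp(v)) * jnp.tanh(v), x)
print(jax.default_backend(), result.shape)
```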

Test Coverage

Category            What's tested
Value correctness   CPU vs MPS output comparison for all 122 ops
Gradient accuracy   jax.grad / jax.value_and_grad for differentiable ops
Edge cases          float16 precision, int64 large values, complex numbers
Regressions         Catastrophic cancellation (log1p), erf_inv range accuracy
Integration         Flax NNX (Linear, Conv, LayerNorm, MultiHeadAttention)
Integration         NumPyro probabilistic programming models

Quality Gate

MPS is unavailable in GitHub Actions, so pre-commit hooks enforce the full quality gate locally on every commit:

  1. clang-format — C/C++/ObjC++ formatting (LLVM style, 110 col)
  2. ruff — Python formatting and linting
  3. build — full native library rebuild
  4. clang-tidy — C++ static analysis (with caching via ctcache)
  5. pytest — full test suite with op coverage enforcement

Benchmarks

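# List available benchmark cases
uv run python -m benchmarks.bench list

# Run the anchor workloads on the Metal backend
JAX_PLATFORMS=mps uv run python -m benchmarks.bench run --case 'anchor\..*' --platform mps

# Run every case and write the results as JSON Lines
uv run python -m benchmarks.bench run --case '.*' --json-out results.jsonl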

The benchmark suite includes per-op micro-benchmarks, representative anchor workloads, and competitive benchmarks against other backends.
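When timing JAX code by hand, two details matter: the first call includes compilation, and JAX dispatches work asynchronously, so the clock should stop only once the result is ready. The sketch below handles both with warmup calls and jax.block_until_ready; the helper time_jitted is hypothetical and is not part of benchmarks/bench.py.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, warmup=2, iters=10):
    """Median seconds per call of jax.jit(fn), excluding compilation."""
    jitted = jax.jit(fn)
    for _ in range(warmup):
        # The first call compiles; later warmup calls settle caches.
        jax.block_until_ready(jitted(*args))
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        # Dispatch is asynchronous, so wait for the result before stopping the clock.
        jax.block_until_ready(jitted(*args))
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


x = jnp.ones((512, 512))
print(f"matmul: {time_jitted(lambda a: a @ a, x) * 1e3:.3f} ms")
```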

Repository Layout

src/
  jax_plugins/silicon/          Python entrypoint — plugin registration & library discovery
  pjrt_plugin/
    api/                        PJRT C API layer (8 implementation files)
      pjrt_api.cc                 Function pointer table (main entry point)
      pjrt_client.cc              Client: compile, buffer creation, platform info
      pjrt_buffer.cc              Buffer: host↔device transfer, copy, clone
      pjrt_executable.cc          Executable: execute, serialize, output metadata
      pjrt_device.cc              Device: attributes, memory, description
      pjrt_event.cc               Event: async completion, create, set
      pjrt_memory.cc              Memory: kind, addressable devices
      pjrt_topology.cc            Topology: device descriptions, platform
    core/                       Backend core (Objective-C++ / Metal)
      mps_client.h/.mm             Metal device & command queue management
      mps_device.h/.mm             GPU device abstraction
      mps_buffer.h/.mm             MTLBuffer wrapper — host copy, clone, blit
      mps_executable.h/.mm         Execution plan builder & runner
      stablehlo_parser.h/.mm       MLIR StableHLO deserializer
      type_utils.h/.mm             MLIR ↔ MPS type conversions
      completion_event.h           Thread-safe async completion primitives
      pjrt_types.h                 PJRT opaque wrapper structs
      logging.h                    Leveled logging macros
    ops/                        StableHLO op handlers (122 registrations)
      registry.h                  Op registry, handler types, macros
      unary_ops.mm                42 unary operations
      binary_ops.mm               15 binary operations
      shape_ops.mm                18 shape & indexing operations
      reduction_ops.mm            7 reduction operations
      convolution_ops.mm          General dilated convolution
      linalg_ops.mm               Cholesky, triangular solve (native MPS)
      bitwise_ops.mm              8 bitwise operations
      sort_ops.mm                 Sort, top-k
      fft_ops.mm                  FFT / RFFT / IFFT / IRFFT
      control_flow_ops.mm         While, case, custom_call
      random_ops.mm               Threefry / Philox RNG
      tensor_creation_ops.mm      Constant, iota
      collective_ops.mm           Single-device collective no-ops
      higher_order_ops.mm         Map
    runtime/                    JIT Metal kernel engine
      metal_kernels.h/.mm          MSL kernel source cache & pipeline compilation
    proto/
      device_assignment.proto     XLA DeviceAssignmentProto definition
tests/
  test_ops.py                   Parametrized test suite (value + gradient)
  test_int64_constant_splats.py Regression: int64 precision above 2^53
  configs/                      Per-category test configurations (15 modules)
benchmarks/
  bench.py                      Benchmark harness (list/run, JSONL output)
  bottlenecks.py                Per-op micro-benchmarks
  vs_jax_mps.py                 Competitive benchmarks
  anchors.py                    Representative anchor workloads
scripts/
  setup_deps.sh                 One-time native dependency bootstrap

Contributing

See CONTRIBUTING.md for the full guide. The short version:

# One-time setup
brew install cmake ninja
./scripts/setup_deps.sh
uv sync --all-groups
uv pip install -e .
pre-commit install

# Rebuild the native library and run the tests
uv pip install -e .
uv run pytest

Acknowledgements

This project draws significant inspiration from MLX by Apple Machine Learning Research. MLX's approach to leveraging Metal and unified memory on Apple Silicon was a major influence on the design and direction of jax-metallib.

License

Copyright 2026 Erfan Zar (@erfanzar). Apache-2.0 — see LICENSE.

Download files

Source Distribution

jax_metallib-0.9.4.5.tar.gz (272.8 kB)

Uploaded Source

Built Distribution

jax_metallib-0.9.4.5-cp313-cp313-macosx_26_0_arm64.whl (6.8 MB)

Uploaded CPython 3.13, macOS 26.0+ ARM64

File details

Details for the file jax_metallib-0.9.4.5.tar.gz.

File metadata

  • Download URL: jax_metallib-0.9.4.5.tar.gz
  • Size: 272.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.21

File hashes

Hashes for jax_metallib-0.9.4.5.tar.gz
Algorithm Hash digest
SHA256 1444757a4e5dd6157f17ccc09a36affa972521e6fbf79ea49c44046e55b027dd
MD5 b9901aac4e63af729f6beb501cd92041
BLAKE2b-256 ecc93a502351f36fe6e708e5de42dabfdf673ca30dd6d7bce9355b26a3697d72

File details

Details for the file jax_metallib-0.9.4.5-cp313-cp313-macosx_26_0_arm64.whl.

File metadata

  • Download URL: jax_metallib-0.9.4.5-cp313-cp313-macosx_26_0_arm64.whl
  • Size: 6.8 MB
  • Tags: CPython 3.13, macOS 26.0+ ARM64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.21

File hashes

Hashes for jax_metallib-0.9.4.5-cp313-cp313-macosx_26_0_arm64.whl
Algorithm Hash digest
SHA256 ddd4aa76def4eb00630de5cf7bc8d13e33c7c2d4939b209e482b66ff66f62e80
MD5 65d325576b823bd0bedab17af69db17e
BLAKE2b-256 e8214fdf355e73c1f00449a1e44b84b81846eb3b95ba43c7c1431d5c19ef083d
