
JAX PJRT plugin for Apple Silicon using MLX


JAX MLX Plugin


A PJRT plugin that lets JAX run on Apple Silicon GPUs via MLX. Write standard JAX code — the plugin handles compilation to Metal compute kernels automatically.

Requirements

  • Apple Silicon Mac (M1/M2/M3/M4)
  • macOS 14.0+ (Sonoma)
  • Python 3.11+

Installation

```bash
pip install jax-mlx-plugin
```

Or from source:

```bash
git clone https://github.com/tsumme1/jax-mlx.git
cd jax-mlx
pip install .
```

Quick Start

```python
import jax
import jax.numpy as jnp

print(jax.devices())  # [MlxDevice(id=0)]

mlx = jax.devices('mlx')[0]
with jax.default_device(mlx):
    x = jnp.array([1.0, 2.0, 3.0])
    y = jnp.sin(x) + jnp.cos(x)
    print(y)  # runs on Metal GPU
```
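If the plugin is missing or fails to load, `jax.devices('mlx')` raises an error instead of returning a device. The sketch below uses only standard JAX APIs (`jax.devices`, `jax.device_put`, `jax.jit`). It falls back to JAX's default device in that case and places data explicitly rather than relying on a context manager. `pick_device` is a hypothetical helper and is not part of the plugin.

```python
import jax
import jax.numpy as jnp


def pick_device(platform="mlx"):
    """Return the first device of `platform`, or JAX's default device if that
    backend is unavailable (for example, when the plugin is not installed)."""
    try:
        return jax.devices(platform)[0]
    except (RuntimeError, ValueError):
        # JAX reports an unknown or failed backend as RuntimeError;
        # ValueError is caught defensively.
        return jax.devices()[0]


device = pick_device()
print(device)

# Commit the input to the chosen device; jit then runs on that device.
x = jax.device_put(jnp.arange(3.0), device)


@jax.jit
def f(v):
    return jnp.sin(v) + jnp.cos(v)


y = f(x)
print(y, y.devices())
```

Because `x` is committed to `device`, the jitted `f` executes there. The same script therefore runs unchanged on machines without the plugin.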

What Works

  • Core ops: Arithmetic, math, reductions, comparisons, bitwise, type conversion
  • Autodiff: jax.grad, value_and_grad, jacfwd, jacrev, hessian
  • Transforms: jax.jit, jax.vmap
  • Control flow: lax.cond, lax.switch, lax.while_loop, lax.scan, lax.fori_loop, lax.map, lax.associative_scan
  • Linear algebra: matmul, solve, inv, cholesky, qr, svd, eig, eigh, triangular_solve, slogdet
  • Neural networks: Flax + Optax (CNNs, MLPs, RNNs, Transformers verified)
  • Convolutions: conv_general_dilated (NHWC/NCHW), pooling (max/min/avg + gradients)
  • FFT: fft, ifft, rfft, irfft, 2D variants
  • Distributions: jax.random.* (Threefry PRNG with 64-bit emulation on Metal)
  • SciPy: scipy.special, scipy.linalg, scipy.stats, scipy.signal, scipy.ndimage
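The list above says standard JAX code should work unchanged. As a small illustration, here is plain JAX that combines several listed features: `jax.grad` and `jax.jit`, `jax.vmap`, `lax.fori_loop`, `lax.scan`, and `jax.random`. It runs on whatever device JAX selects by default. Nothing in it is plugin-specific, and this snippet does not itself verify MLX support. It also runs on CPU. The function names are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax


def loss(w, x, y):
    # Mean squared error of a linear model x @ w.
    return jnp.mean((x @ w - y) ** 2)


@jax.jit
def train(w, x, y):
    # lax.fori_loop runs 2000 gradient-descent steps inside one compiled program.
    def body(_, w):
        return w - 0.1 * jax.grad(loss)(w, x, y)
    return lax.fori_loop(0, 2000, body, w)


# vmap: per-example losses without a Python loop over rows.
per_example = jax.vmap(loss, in_axes=(None, 0, 0))


def cumsum_scan(xs):
    # lax.scan threads a running total through the sequence.
    def step(total, xi):
        total = total + xi
        return total, total
    _, out = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return out


x = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
w_true = jnp.array([1.0, -2.0, 0.5])
y = x @ w_true

w = train(jnp.zeros(3), x, y)
print(w)                              # should be close to w_true
print(per_example(w, x, y).shape)     # (8,)
print(cumsum_scan(jnp.arange(4.0)))   # [0. 1. 3. 6.]
```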

See ARCHITECTURE.md for technical details.

Benchmarks

Four benchmark suites compare JAX-MLX against JAX CPU and native MLX:

  • CNN (Conv + Pool + Dense): python benchmarks/benchmark_cnn.py
  • MLP (Dense layers): python benchmarks/benchmark_mlp.py
  • RNN (Recurrent): python benchmarks/benchmark_rnn.py
  • Transformer (Attention): python benchmarks/benchmark_transformer.py

Each script also has a _native.py variant that runs the same benchmark directly in MLX, for comparison.
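The benchmark scripts' own timing code isn't reproduced here. If you time JAX code yourself, two JAX-specific pitfalls matter. First, the first call to a jitted function includes tracing and compilation. Second, dispatch is asynchronous, so a call can return before the device has finished. The sketch below handles both with warm-up calls and `jax.block_until_ready`. `time_jitted` is a hypothetical helper, not the repository's benchmark code.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, warmup=2, iters=10):
    """Median wall-clock seconds per call of a jitted function.

    Warm-up calls absorb tracing/compilation; block_until_ready() waits for
    the asynchronously dispatched work to finish before the timer stops.
    """
    for _ in range(warmup):
        jax.block_until_ready(fn(*args))
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))
        times.append(time.perf_counter() - start)
    return statistics.median(times)


matmul = jax.jit(lambda a, b: a @ b)
a = jnp.ones((512, 512), jnp.float32)
print(f"{time_jitted(matmul, a, a) * 1e3:.3f} ms")
```

To compare backends, commit the inputs to a specific device first, for example `jax.device_put(a, jax.devices("cpu")[0])`. A jitted function then runs on the device its committed inputs live on.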

Testing

```bash
# Exhaustive op coverage (387 ops)
python tests/test_exhaustive.py

# Every op wrapped in lax.while_loop (362 ops)
python tests/test_exhaustive_while.py

# Compilation tier coverage
python tests/test_compilation_coverage.py
```
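The internals of test_exhaustive_while.py aren't described here. What follows is one plausible reading of "every op wrapped in lax.while_loop", offered as an assumption: run the op inside a single-iteration loop and check that the result matches the op run directly. `run_in_while_loop` and `check_op` are hypothetical names, and this sketch only handles ops that return a single array.

```python
import jax
import jax.numpy as jnp
from jax import lax


def run_in_while_loop(op, *args):
    """Apply `op` once inside a single-iteration lax.while_loop."""
    out = jax.eval_shape(op, *args)  # shape/dtype of op's result, no compute

    def cond(state):
        i, _ = state
        return i < 1

    def body(state):
        i, _ = state
        return i + 1, op(*args)

    init = (jnp.int32(0), jnp.zeros(out.shape, out.dtype))
    _, result = lax.while_loop(cond, body, init)
    return result


def check_op(op, *args, atol=1e-6):
    """Compare op's jitted result with and without the while_loop wrapper."""
    direct = jax.jit(op)(*args)
    looped = jax.jit(lambda *a: run_in_while_loop(op, *a))(*args)
    return bool(jnp.allclose(direct, looped, atol=atol))


x = jnp.linspace(-1.0, 1.0, 5)
for name, op in [("sin", jnp.sin), ("tanh", jnp.tanh), ("cumsum", jnp.cumsum)]:
    print(name, check_op(op, x))
```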

Known Limitations

  • Float64: not natively supported on Metal; use float32 instead.
  • While loops: a while loop prevents kernel fusion across the enclosing graph (the segments inside the loop are still compiled).
  • LAPACK ops: LU factorization and slogdet fall back to a CPU interpreter.
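Under JAX's default configuration (`jax_enable_x64` off), float64 NumPy inputs are already converted to float32 when they become JAX arrays. An explicit cast still makes the intent visible, and it continues to hold if x64 has been enabled elsewhere. The sketch below casts only the float64 leaves of a parameter pytree and leaves other leaves untouched. `to_float32` is a hypothetical helper, not part of the plugin.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_float32(tree):
    """Cast float64 leaves of a pytree to float32; leave everything else as is."""
    def cast(leaf):
        if getattr(leaf, "dtype", None) == np.float64:
            return jnp.asarray(leaf, dtype=jnp.float32)
        return leaf
    return jax.tree_util.tree_map(cast, tree)


params = {
    "w": np.random.default_rng(0).normal(size=(3, 2)),  # NumPy defaults to float64
    "steps": np.arange(3),                              # integer leaf, left untouched
}
params32 = to_float32(params)
print({k: v.dtype for k, v in params32.items()})
```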

License

MIT — see LICENSE.



Download files


Source Distribution

jax_mlx_plugin-0.0.4.tar.gz (177.6 kB)


File details

Details for the file jax_mlx_plugin-0.0.4.tar.gz.

File metadata

  • Download URL: jax_mlx_plugin-0.0.4.tar.gz
  • Size: 177.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.2

File hashes

Hashes for jax_mlx_plugin-0.0.4.tar.gz:

  • SHA256: cb26c3ad78bed78537f260c233e89233edfb6c785ca2ead7c275a420ee77f289
  • MD5: 3d7834b9c0e1d67d0367a70ccfcf779c
  • BLAKE2b-256: 9eb1d20a02f133eabfcb7124adb268a9eece24a208ffaa4aca5a4d53bdf1da65
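To check a downloaded sdist against the SHA256 digest listed for this release, you can use a small standard-library sketch. `sha256_of` and `verify` are hypothetical helper names.

```python
import hashlib
from pathlib import Path

# SHA256 digest published for jax_mlx_plugin-0.0.4.tar.gz.
EXPECTED_SHA256 = "cb26c3ad78bed78537f260c233e89233edfb6c785ca2ead7c275a420ee77f289"


def sha256_of(path, chunk_size=1 << 20):
    """Hex SHA256 of a file, read in chunks so large files don't fill memory."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 matches the expected hex digest."""
    return sha256_of(path) == expected.lower()
```

Usage: `verify("jax_mlx_plugin-0.0.4.tar.gz")`. pip can also enforce this itself through its hash-checking mode, using `--hash=sha256:<digest>` entries in a requirements file together with `--require-hashes`.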

