JAX PJRT plugin for Apple Silicon using MLX
Project description
JAX MLX Plugin
A PJRT plugin that lets JAX run on Apple Silicon GPUs via MLX. Write standard JAX code — the plugin handles compilation to Metal compute kernels automatically.
Requirements
- Apple Silicon Mac (M1/M2/M3/M4)
- macOS 14.0+ (Sonoma)
- Python 3.11+
Installation
pip install jax-mlx-plugin
Or from source:
git clone https://github.com/tsumme1/jax-mlx.git
cd jax-mlx
pip install .
Quick Start
import jax
import jax.numpy as jnp
print(jax.devices()) # [MlxDevice(id=0)]
mlx = jax.devices('mlx')[0]
with jax.default_device(mlx):
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.sin(x) + jnp.cos(x)
print(y) # runs on Metal GPU
What Works
| Category | Details |
|---|---|
| Core ops | Arithmetic, math, reductions, comparisons, bitwise, type conversion |
| Autodiff | jax.grad, value_and_grad, jacfwd, jacrev, hessian |
| Transforms | jax.jit, jax.vmap |
| Control flow | lax.cond, lax.switch, lax.while_loop, lax.scan, lax.fori_loop, lax.map, lax.associative_scan |
| Linear algebra | matmul, solve, inv, cholesky, qr, svd, eig, eigh, triangular_solve, slogdet |
| Neural networks | Flax + Optax (CNNs, MLPs, RNNs, Transformers verified) |
| Convolutions | conv_general_dilated (NHWC/NCHW), pooling (max/min/avg + gradients) |
| FFT | fft, ifft, rfft, irfft, 2D variants |
| Distributions | jax.random.* (Threefry PRNG with 64-bit emulation on Metal) |
| SciPy | scipy.special, scipy.linalg, scipy.stats, scipy.signal, scipy.ndimage |
See ARCHITECTURE.md for technical details.
Benchmarks
Four benchmark suites compare JAX-MLX against JAX CPU and native MLX:
| Benchmark | Command |
|---|---|
| CNN (Conv + Pool + Dense) | python benchmarks/benchmark_cnn.py |
| MLP (Dense layers) | python benchmarks/benchmark_mlp.py |
| RNN (Recurrent) | python benchmarks/benchmark_rnn.py |
| Transformer (Attention) | python benchmarks/benchmark_transformer.py |
Each also has a _native.py variant for direct MLX comparison.
Testing
# Exhaustive op coverage (387 ops)
python tests/test_exhaustive.py
# Every op wrapped in lax.while_loop (362 ops)
python tests/test_exhaustive_while.py
# Compilation tier coverage
python tests/test_compilation_coverage.py
Known Limitations
- Float64 — Not natively supported on Metal; use Float32
- While loops — Block kernel fusion for the enclosing graph (segments within are still compiled)
- LAPACK ops — LU factorization, slogdet use CPU interpreter fallback
License
MIT — see LICENSE.
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
jax_mlx_plugin-0.0.4.tar.gz
(177.6 kB
view details)
File details
Details for the file jax_mlx_plugin-0.0.4.tar.gz.
File metadata
- Download URL: jax_mlx_plugin-0.0.4.tar.gz
- Upload date:
- Size: 177.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.2
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
cb26c3ad78bed78537f260c233e89233edfb6c785ca2ead7c275a420ee77f289
|
|
| MD5 |
3d7834b9c0e1d67d0367a70ccfcf779c
|
|
| BLAKE2b-256 |
9eb1d20a02f133eabfcb7124adb268a9eece24a208ffaa4aca5a4d53bdf1da65
|