
JAX backend for Apple M-series chips

Project description

jax-silicon

jax-silicon is a PJRT plugin that lets JAX run on the GPUs of Apple Silicon Macs via Metal (MPS).

Requirements

  • macOS 13+ on Apple Silicon
  • Python 3.11+
  • jax and jaxlib 0.9.x
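These requirements can be checked from Python before installing. This is a minimal standard-library sketch: requirement_problems and installed_version are hypothetical helper names, and the checks only mirror the bullets above (for example, "0.9.x" is treated as any version string starting with "0.9.").

```python
import platform
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def requirement_problems(system, machine, mac_release, py_version, versions):
    """Compare environment facts against the stated requirements.

    Returns a list of human-readable problems; an empty list means every check passed.
    """
    problems = []
    if system != "Darwin" or machine != "arm64":
        problems.append(f"needs macOS on Apple Silicon, got {system}/{machine}")
    elif int((mac_release or "0").split(".")[0]) < 13:
        problems.append(f"needs macOS 13+, got {mac_release}")
    if tuple(py_version[:2]) < (3, 11):
        problems.append(f"needs Python 3.11+, got {py_version[0]}.{py_version[1]}")
    for dist in ("jax", "jaxlib"):
        found = versions.get(dist)
        if found is None:
            problems.append(f"{dist} is not installed")
        elif not found.startswith("0.9."):
            problems.append(f"needs {dist} 0.9.x, got {found}")
    return problems


if __name__ == "__main__":
    issues = requirement_problems(
        platform.system(),
        platform.machine(),
        platform.mac_ver()[0],  # empty string on non-macOS systems
        sys.version_info,
        {dist: installed_version(dist) for dist in ("jax", "jaxlib")},
    )
    print("\n".join(issues) if issues else "all stated requirements met")
```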

Install

pip install jax-silicon

Build From Source

brew install cmake ninja
uv pip install -e .

On the first build, any missing native dependencies are bootstrapped automatically by running scripts/setup_deps.sh.

To disable the automatic bootstrap, pass the CMake option through CMAKE_ARGS:

CMAKE_ARGS="-DJAX_SILICON_AUTO_SETUP_DEPS=OFF" uv pip install -e .

Use

Although the project is named jax-silicon, the plugin's backend name in JAX is still mps, so select it with JAX_PLATFORMS=mps:

JAX_PLATFORMS=mps python -c "import jax; print(jax.devices())"

Optional environment variables that override the path to the plugin library:

  • JAX_SILICON_LIBRARY_PATH
  • JAX_MPS_LIBRARY_PATH (legacy name, kept for compatibility)
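The precedence between these two variables is not documented here. As a hypothetical sketch only, assuming the newer JAX_SILICON_LIBRARY_PATH wins over the legacy JAX_MPS_LIBRARY_PATH and that an unset or empty value means "use the bundled library", the lookup could look like this (library_path_override is an illustrative name, not the plugin's API):

```python
import os
from typing import Mapping, Optional

# Checked in order; the newer name winning is an assumption, not confirmed behavior.
LIBRARY_PATH_VARS = ("JAX_SILICON_LIBRARY_PATH", "JAX_MPS_LIBRARY_PATH")


def library_path_override(environ: Mapping[str, str] = os.environ) -> Optional[str]:
    """Return the first non-empty override, or None to fall back to the default library."""
    for name in LIBRARY_PATH_VARS:
        value = environ.get(name, "")
        if value:
            return value
    return None


if __name__ == "__main__":
    print(library_path_override())
```

Passing a plain dict instead of os.environ keeps the function easy to test without touching the real environment.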

Test

uv run pytest

Repository Layout

  • src/jax_plugins/silicon/: Python plugin entry point
  • src/pjrt_plugin/: C++/Objective-C++ PJRT implementation
  • scripts/setup_deps.sh: dependency bootstrap script
  • tests/: test suite

License

Apache-2.0.

This repository is a derivative of tillahoffmann/jax-mps.



Download files


Source Distribution

jax_metallib-0.9.0.tar.gz (158.7 kB)

Built Distribution


jax_metallib-0.9.0-cp313-cp313-macosx_26_0_arm64.whl (6.7 MB; CPython 3.13, macOS 26.0+ ARM64)
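The wheel's tags determine where it installs: cp313-cp313 means CPython 3.13 (interpreter and ABI), and macosx_26_0_arm64 means macOS 26.0 or later on arm64. That is narrower than the Requirements list, so on other Python or macOS versions pip would presumably fall back to building the source distribution. A minimal standard-library sketch of splitting the file name into these tags (it ignores the optional build tag; the helper names are illustrative):

```python
import re
from typing import NamedTuple, Optional, Tuple


class WheelTags(NamedTuple):
    name: str
    version: str
    python_tag: str
    abi_tag: str
    platform_tag: str


def parse_wheel_filename(filename: str) -> WheelTags:
    """Split {name}-{version}(-{build})?-{python}-{abi}-{platform}.whl into its parts."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) not in (5, 6):  # 6 parts when an optional build tag is present
        raise ValueError(f"unexpected wheel name: {filename}")
    python_tag, abi_tag, platform_tag = parts[-3:]
    return WheelTags(parts[0], parts[1], python_tag, abi_tag, platform_tag)


def min_macos(platform_tag: str) -> Optional[Tuple[int, int, str]]:
    """Return (major, minor, arch) from a macosx_X_Y_arch tag, else None."""
    m = re.fullmatch(r"macosx_(\d+)_(\d+)_(\w+)", platform_tag)
    return (int(m[1]), int(m[2]), m[3]) if m else None


tags = parse_wheel_filename("jax_metallib-0.9.0-cp313-cp313-macosx_26_0_arm64.whl")
print(tags)
print(min_macos(tags.platform_tag))  # → (26, 0, 'arm64')
```

For anything beyond illustration, packaging.utils.parse_wheel_filename from the packaging library handles the full file name specification, including build tags and compressed tag sets.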

File details

Details for the file jax_metallib-0.9.0.tar.gz.

File metadata

  • Download URL: jax_metallib-0.9.0.tar.gz
  • Size: 158.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv 0.9.21 (uv publish)

File hashes

Hashes for jax_metallib-0.9.0.tar.gz
Algorithm Hash digest
SHA256 b37c235cbe811569f57edd4a24cc327f4cbc714efd755d3c16f0b610b043685e
MD5 9e216c9c77c1941d0da08091da4350de
BLAKE2b-256 16b7e0fbfd19d490134d59fbe8e7d48d136dd3d3487b9f8bb0606d932231bf0c
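To check a downloaded file against the published digests, recompute its SHA256 and compare. A minimal sketch for the source distribution (swap in the wheel's SHA256 to check that file instead); sha256_of and verify are illustrative helper names:

```python
import hashlib
import sys
from pathlib import Path

# SHA256 digest published above for jax_metallib-0.9.0.tar.gz.
EXPECTED_SHA256 = "b37c235cbe811569f57edd4a24cc327f4cbc714efd755d3c16f0b610b043685e"


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash the file in 1 MiB chunks so it is never read into memory all at once."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """Compare case-insensitively, since digests are sometimes shown in upper case."""
    return sha256_of(path) == expected.lower()


if __name__ == "__main__":
    target = Path(sys.argv[1] if len(sys.argv) > 1 else "jax_metallib-0.9.0.tar.gz")
    if target.exists():
        print("OK" if verify(target) else "MISMATCH: do not install this file")
    else:
        print(f"{target} not found")
```

pip can enforce the same check at install time with its hash-checking mode: list the requirement with --hash=sha256:<digest> in a requirements file and install with --require-hashes.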


File details

Details for the file jax_metallib-0.9.0-cp313-cp313-macosx_26_0_arm64.whl.

File metadata

  • Download URL: jax_metallib-0.9.0-cp313-cp313-macosx_26_0_arm64.whl
  • Size: 6.7 MB
  • Tags: CPython 3.13, macOS 26.0+ ARM64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv 0.9.21 (uv publish)

File hashes

Hashes for jax_metallib-0.9.0-cp313-cp313-macosx_26_0_arm64.whl
Algorithm Hash digest
SHA256 28e147d043b447e2911e6fdccf26551efe3f69b8efd79eca1d512dee323a149a
MD5 b2e6036cd3d6e9c44ea0df9f8764da8d
BLAKE2b-256 f527278ad4711081b09d5c30d8be1fe3bde857cd46957f708ad8ae5814689579

