JAX acceleration for Mac GPUs.

Project description

jax-metal

The jax-metal package is a Metal GPU plugin that provides Metal acceleration for JAX applications on Mac platforms.

Release Notes

v0.1.1

This patch release fixes the handling of the IR ops TopKOp, TanOp, and ErfOp, which have changed since jaxlib 0.4.30.

Installation

The following table lists each jax-metal release with its compatible macOS, jaxlib, and jax versions.

jax-metal  macOS                              jaxlib               jax
0.1.1      Sonoma 14.4+                       >=0.4.34             ==jaxlib
0.1.0      Sonoma 14.4+                       >=0.4.26             >=0.4.26
0.0.7      Sonoma 14.4+                       >=0.4.26             >=0.4.26
0.0.6      Sonoma 14.4 Beta                   >=v0.4.22, <v0.4.24  >=v0.4.22
0.0.5      Sonoma 14.2+                       >=v0.4.20, <v0.4.22  >=v0.4.20
0.0.4      Sonoma 14.0+                       v0.4.11              v0.4.11
0.0.3      Ventura 13.4.1+, Sonoma 14.0 Beta  v0.4.10              v0.4.11
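The table can also be encoded as data for a quick pre-install check. The sketch below is a hypothetical helper, not part of jax-metal. It assumes that beta macOS releases count as their release number, writes the exact jaxlib pins in the 0.0.3 and 0.0.4 rows as half-open ranges, compares numeric version components only, and does not check the jax column.

```python
from typing import Dict, Optional, Tuple

Version = Tuple[int, ...]

# Hypothetical encoding of the compatibility table:
# jax-metal -> (minimum macOS, minimum jaxlib, exclusive jaxlib upper bound or None)
COMPAT: Dict[str, Tuple[str, str, Optional[str]]] = {
    "0.1.1": ("14.4", "0.4.34", None),
    "0.1.0": ("14.4", "0.4.26", None),
    "0.0.7": ("14.4", "0.4.26", None),
    "0.0.6": ("14.4", "0.4.22", "0.4.24"),  # macOS 14.4 Beta treated as 14.4
    "0.0.5": ("14.2", "0.4.20", "0.4.22"),
    "0.0.4": ("14.0", "0.4.11", "0.4.12"),  # exact pin v0.4.11
    "0.0.3": ("13.4.1", "0.4.10", "0.4.11"),  # exact pin v0.4.10
}


def parse(version: str) -> Version:
    """Turn 'v0.4.22' or '14.4.1' into a tuple of ints, stopping at the first non-numeric part."""
    parts = []
    for piece in version.lstrip("v").split("."):
        if not piece.isdigit():
            break  # e.g. a trailing 'dev20240101' component
        parts.append(int(piece))
    return tuple(parts)


def is_compatible(jax_metal: str, macos: str, jaxlib: str) -> bool:
    """Return True if the macOS and jaxlib versions fall inside the table row for jax_metal."""
    min_macos, jaxlib_min, jaxlib_max = COMPAT[jax_metal]
    if parse(macos) < parse(min_macos):
        return False
    if parse(jaxlib) < parse(jaxlib_min):
        return False
    if jaxlib_max is not None and parse(jaxlib) >= parse(jaxlib_max):
        return False
    return True
```

On a Mac, `platform.mac_ver()[0]` returns the running macOS version string (for example `"14.4.1"`), which can be passed as the `macos` argument.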

We recommend installing the binary package in a venv or conda environment:

python3 -m venv ~/jax-metal
source ~/jax-metal/bin/activate
python -m pip install -U pip
python -m pip install numpy wheel
python -m pip install jax-metal
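Because the recommendation is to install into an isolated environment, a small check can confirm that the current interpreter is inside one before running `pip install`. This is a minimal sketch that relies on two general Python and conda conventions rather than anything jax-metal defines: a venv makes `sys.prefix` differ from `sys.base_prefix`, and an activated conda environment exports `CONDA_PREFIX`. The function name `in_isolated_env` is illustrative.

```python
import os
import sys


def in_isolated_env() -> bool:
    """Best-effort check for an active venv or conda environment."""
    # Inside a venv, sys.prefix points at the environment while
    # sys.base_prefix still points at the base interpreter.
    if sys.prefix != sys.base_prefix:
        return True
    # `conda activate` exports CONDA_PREFIX for the active environment.
    return bool(os.environ.get("CONDA_PREFIX"))


if not in_isolated_env():
    print("warning: not in a venv or conda environment; consider creating one first")
```

Conda's base environment also sets `CONDA_PREFIX`, so this check cannot tell the base environment apart from a dedicated one.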

Usage

python -c 'import jax; print(jax.numpy.arange(10))'

Compatibility with jaxlib

jax-metal is compatible with the minimum jaxlib version listed in the table above. To run it with newer jaxlib releases, set the environment variable ENABLE_PJRT_COMPATIBILITY=1:

pip install -U jaxlib jax
ENABLE_PJRT_COMPATIBILITY=1 python -c 'import jax; print(jax.numpy.arange(10))'
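The same compatibility run can be scripted from Python. The sketch below launches the smoke test in a child interpreter with `ENABLE_PJRT_COMPATIBILITY=1` added to a copy of the current environment, so the setting does not leak into the calling process. The helper name `run_with_pjrt_compat` is hypothetical.

```python
import os
import subprocess
import sys

# The smoke test from the Usage section.
SMOKE_TEST = "import jax; print(jax.numpy.arange(10))"


def run_with_pjrt_compat(code: str = SMOKE_TEST) -> subprocess.CompletedProcess:
    """Run `python -c code` with ENABLE_PJRT_COMPATIBILITY=1 set for the child only."""
    env = dict(os.environ, ENABLE_PJRT_COMPATIBILITY="1")
    return subprocess.run(
        [sys.executable, "-c", code],
        env=env,
        capture_output=True,
        text=True,
    )
```

Calling `run_with_pjrt_compat()` and inspecting `result.returncode` and `result.stdout` mirrors the shell one-liner above. Running in a child process also avoids depending on exactly when the variable is read during JAX start-up.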

Currently not supported

The Metal plug-in is experimental, and some JAX functionality may not be supported. Reported issues are tracked in the JAX issue list: https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+metal

Refer to https://developer.apple.com/metal/jax/ for full setup and verification instructions.

Download files

Download the file for your platform.

Source Distributions

No source distribution files are available for this release.

Built Distributions

jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl (41.2 MB)

Python 3, macOS 13.0+, ARM64

jax_metal-0.1.1-py3-none-macosx_10_14_x86_64.whl (54.7 MB)

Python 3, macOS 10.14+, x86-64

File details

Details for the file jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl.

File hashes

Hashes for jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl
Algorithm Hash digest
SHA256 f1dbfecb298cdd3ba6da3ad6dc9a2adb63d71741f8b8ece28c296b32d608b6c8
MD5 3435ef5a9a7f2a01218264c67c61ea5d
BLAKE2b-256 09dc6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d

File details

Details for the file jax_metal-0.1.1-py3-none-macosx_10_14_x86_64.whl.

File hashes

Hashes for jax_metal-0.1.1-py3-none-macosx_10_14_x86_64.whl
Algorithm Hash digest
SHA256 d918a78443cb808c9491a24a5c2a94cc4eabfd0461d5bcda29a8f332dfbe9b7e
MD5 81c3cbc8f16c217ad8b125a990941214
BLAKE2b-256 87ec9bb7f7f0ffd06c3fb89813126b2f698636ac7a4263ed7bdd1ff7d7c94f8f
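To check a downloaded wheel against the digests listed above, a short script can recompute the SHA256 hash with Python's standard `hashlib` module. This is a minimal sketch: the helper names are illustrative, and only the SHA256 values are checked.

```python
import hashlib
from pathlib import Path

# Published SHA256 digests for the jax-metal 0.1.1 wheels, copied from the tables above.
EXPECTED_SHA256 = {
    "jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl":
        "f1dbfecb298cdd3ba6da3ad6dc9a2adb63d71741f8b8ece28c296b32d608b6c8",
    "jax_metal-0.1.1-py3-none-macosx_10_14_x86_64.whl":
        "d918a78443cb808c9491a24a5c2a94cc4eabfd0461d5bcda29a8f332dfbe9b7e",
}


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash the file in 1 MiB chunks so large wheels are not read into memory at once."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_wheel(path: Path) -> bool:
    """Return True if the file's SHA256 matches the published digest for its file name."""
    expected = EXPECTED_SHA256.get(path.name)
    if expected is None:
        raise ValueError(f"no published digest for {path.name}")
    return sha256_of(path) == expected
```

pip can enforce the same check during installation through its hash-checking mode, for example with a requirements line such as `jax-metal==0.1.1 --hash=sha256:<digest>`. In that mode pip requires pinned hashes for every package it installs, including dependencies.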
