JAX acceleration for Mac GPUs.
Project description
jax-metal
The jax-metal package is a Metal GPU plugin that provides Metal acceleration for JAX applications on Mac platforms.
Release Notes
v0.1.1
This patch fixes handling of the TopKOp, TanOp, and ErfOp IR ops, which have changed since jaxlib 0.4.30.
Installation
The following table tracks jax-metal versions and the compatible versions of jax, jaxlib, and macOS:
jax-metal | macOS | jaxlib | jax |
---|---|---|---|
0.1.1 | Sonoma 14.4+ | >=0.4.34 | same version as jaxlib |
0.1.0 | Sonoma 14.4+ | >=0.4.26 | >=0.4.26 |
0.0.7 | Sonoma 14.4+ | >=0.4.26 | >=0.4.26 |
0.0.6 | Sonoma 14.4 Beta | >=0.4.22, <0.4.24 | >=0.4.22 |
0.0.5 | Sonoma 14.2+ | >=0.4.20, <0.4.22 | >=0.4.20 |
0.0.4 | Sonoma 14.0+ | 0.4.11 | 0.4.11 |
0.0.3 | Ventura 13.4.1+, Sonoma 14.0 Beta | 0.4.10 | 0.4.11 |
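The table can also be read as a lookup. The sketch below transcribes the jaxlib column into a Python dictionary and checks whether a given jaxlib release falls inside the range listed for a jax-metal version. `JAXLIB_RANGES`, `parse_version`, and `jaxlib_supported` are hypothetical names for illustration, not part of jax-metal. Rows pinned to a single release (0.0.4, 0.0.3) are modeled as a half-open range up to the next patch version, and the jax column is not modeled. The parser only handles plain numeric releases; for pre-releases or local versions, `packaging.version.Version` is a more robust choice.

```python
from typing import Dict, Optional, Tuple

Version = Tuple[int, ...]


def parse_version(text: str) -> Version:
    """Parse a plain release string such as 'v0.4.22' or '0.4.34' into a tuple of ints."""
    return tuple(int(part) for part in text.strip().lstrip("v").split("."))


# Transcribed from the compatibility table: jax-metal version -> (minimum jaxlib,
# exclusive upper bound, or None when the table gives no upper limit).
# Rows pinned to one jaxlib release are modeled as [release, next patch).
JAXLIB_RANGES: Dict[str, Tuple[str, Optional[str]]] = {
    "0.1.1": ("0.4.34", None),
    "0.1.0": ("0.4.26", None),
    "0.0.7": ("0.4.26", None),
    "0.0.6": ("0.4.22", "0.4.24"),
    "0.0.5": ("0.4.20", "0.4.22"),
    "0.0.4": ("0.4.11", "0.4.12"),
    "0.0.3": ("0.4.10", "0.4.11"),
}


def jaxlib_supported(jax_metal: str, jaxlib: str) -> bool:
    """Return True if `jaxlib` lies in the range the table lists for `jax_metal`.

    Raises KeyError for jax-metal versions that are not in the table.
    """
    low, high = JAXLIB_RANGES[jax_metal]
    version = parse_version(jaxlib)
    if version < parse_version(low):
        return False
    return high is None or version < parse_version(high)


if __name__ == "__main__":
    print(jaxlib_supported("0.0.6", "0.4.23"))  # True: inside [0.4.22, 0.4.24)
    print(jaxlib_supported("0.1.1", "0.4.30"))  # False: below the 0.4.34 minimum
```

Tuples of integers compare element by element, which is why "0.4.10" correctly sorts after "0.4.9" here, something a plain string comparison would get wrong.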
We recommend installing the binary package in a venv or conda environment.
python3 -m venv ~/jax-metal
source ~/jax-metal/bin/activate
python -m pip install -U pip
python -m pip install numpy wheel
python -m pip install jax-metal
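After the install steps, a short script can confirm that the virtual environment is active and report which versions of jax-metal, jaxlib, and jax pip installed, so they can be compared with the table. This is a minimal sketch using only the standard library (`importlib.metadata`); the function names are illustrative. The venv check compares `sys.prefix` with `sys.base_prefix`, which detects venv but not conda, because a conda environment is its own base interpreter; inside conda, the `CONDA_PREFIX` environment variable is a common alternative check.

```python
import sys
from importlib import metadata
from typing import Dict, Iterable, Optional


def in_virtualenv() -> bool:
    """True inside a venv, where sys.prefix differs from the base interpreter's prefix."""
    return sys.prefix != sys.base_prefix


def installed_versions(
    packages: Iterable[str] = ("jax-metal", "jaxlib", "jax"),
) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("venv active:", in_virtualenv())
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```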
Usage
python -c 'import jax; print(jax.numpy.arange(10))'
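The one-liner above prints an array. A slightly longer smoke test, sketched below, also reports which backend and devices JAX selected, which is a quick way to tell whether the Metal plug-in was picked up or JAX fell back to the CPU. It uses the public `jax.default_backend()` and `jax.devices()` functions and returns a message instead of crashing when jax or jaxlib cannot be imported. The exact backend and device names the Metal plug-in reports are not hard-coded here; compare the output against what your installation prints.

```python
def jax_smoke_test() -> str:
    """Run the README's arange example and report the backend JAX selected."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as exc:  # jax or jaxlib is missing from this environment
        return f"jax unavailable: {exc}"

    x = jnp.arange(10)
    total = int(jnp.sum(x))  # 0 + 1 + ... + 9 = 45 on any backend
    return f"backend={jax.default_backend()} devices={jax.devices()} sum={total}"


if __name__ == "__main__":
    print(jax_smoke_test())
```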
Compatibility with jaxlib
jax-metal is compatible with the minimum jaxlib version listed in the table above. To run it with a newer jaxlib, set ENABLE_PJRT_COMPATIBILITY=1 in the environment, for example:
pip install -U jaxlib jax
ENABLE_PJRT_COMPATIBILITY=1 python -c 'import jax; print(jax.numpy.arange(10))'
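The same run can be launched from Python. Since this page does not say exactly when the plug-in reads ENABLE_PJRT_COMPATIBILITY, this sketch mirrors the shell example and starts a fresh interpreter with the variable already set in its environment, rather than setting it inside a process that may already have imported jax. `pjrt_compat_env` and `run_with_pjrt_compat` are hypothetical helper names.

```python
import os
import subprocess
import sys
from typing import Dict, Mapping, Optional


def pjrt_compat_env(base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Copy an environment (default: the current one) and set ENABLE_PJRT_COMPATIBILITY=1."""
    env = dict(os.environ if base is None else base)
    env["ENABLE_PJRT_COMPATIBILITY"] = "1"
    return env


def run_with_pjrt_compat(code: str) -> str:
    """Run `code` in a fresh interpreter with the compatibility variable set.

    Returns the child's stdout; raises subprocess.CalledProcessError if it exits non-zero.
    """
    result = subprocess.run(
        [sys.executable, "-c", code],
        env=pjrt_compat_env(),
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout
```

Calling `print(run_with_pjrt_compat("import jax; print(jax.numpy.arange(10))"))` reproduces the shell command above, and the parent process's own environment is left unchanged.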
Currently not supported
The Metal plug-in is experimental, and some JAX functionality may not be supported. Reported issues are tracked at: https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+metal
- Unsupported data types: np.float64, np.complex64, np.complex128
- The Metal plug-in doesn’t pass all tests under https://github.com/google/jax/tree/main/tests.
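Code that may receive double-precision or complex data can check dtypes before the data reaches the Metal device. The sketch below assumes NumPy inputs and uses a hypothetical helper name, `prepare_for_metal`. It downcasts float64 to float32 (losing precision) and raises on complex dtypes, since complex64 is already the smallest complex type and is itself on the unsupported list. Note that JAX in its default 32-bit mode (`jax_enable_x64` off) already converts float64 inputs to float32 when creating arrays; the helper just makes that choice explicit and fails early on complex data.

```python
import numpy as np


def prepare_for_metal(x) -> np.ndarray:
    """Return `x` as a NumPy array whose dtype avoids the unsupported types listed above.

    float64 is downcast to float32, which loses precision. Complex inputs raise,
    because complex64 is already the smallest complex dtype and is unsupported.
    """
    arr = np.asarray(x)
    if arr.dtype == np.float64:
        return arr.astype(np.float32)
    if np.issubdtype(arr.dtype, np.complexfloating):
        raise TypeError(f"{arr.dtype} is not supported by jax-metal")
    return arr
```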
Please refer to https://developer.apple.com/metal/jax/ for full setup and verification instructions.
Download files
Built Distributions
File details
Details for the file jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl.
File metadata
- Download URL: jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl
- Upload date:
- Size: 41.2 MB
- Tags: Python 3, macOS 13.0+ ARM64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.10.14
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | f1dbfecb298cdd3ba6da3ad6dc9a2adb63d71741f8b8ece28c296b32d608b6c8 |
MD5 | 3435ef5a9a7f2a01218264c67c61ea5d |
BLAKE2b-256 | 09dc6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d |
File details
Details for the file jax_metal-0.1.1-py3-none-macosx_10_14_x86_64.whl.
File metadata
- Download URL: jax_metal-0.1.1-py3-none-macosx_10_14_x86_64.whl
- Upload date:
- Size: 54.7 MB
- Tags: Python 3, macOS 10.14+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.10.14
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | d918a78443cb808c9491a24a5c2a94cc4eabfd0461d5bcda29a8f332dfbe9b7e |
MD5 | 81c3cbc8f16c217ad8b125a990941214 |
BLAKE2b-256 | 87ec9bb7f7f0ffd06c3fb89813126b2f698636ac7a4263ed7bdd1ff7d7c94f8f |
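The published digests can be used to check a downloaded wheel before installing it. This is a minimal sketch using `hashlib`; `file_digest` and `verify_wheel` are illustrative names. It covers SHA256 and BLAKE2b-256 (BLAKE2b with a 32-byte digest, matching the 64-hex-character values in the tables) and skips MD5, which is not a secure integrity check.

```python
import hashlib
from pathlib import Path


def file_digest(path, algorithm: str = "sha256", chunk_size: int = 1 << 20) -> str:
    """Hex digest of a file, read in chunks so large wheels are not loaded at once.

    Supports "sha256" and "blake2b-256" (BLAKE2b with a 32-byte digest).
    """
    if algorithm == "sha256":
        hasher = hashlib.sha256()
    elif algorithm == "blake2b-256":
        hasher = hashlib.blake2b(digest_size=32)
    else:
        raise ValueError(f"unsupported algorithm: {algorithm}")
    with Path(path).open("rb") as handle:
        # iter() with a sentinel keeps reading until read() returns b"" at EOF.
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


def verify_wheel(path, expected_hex: str, algorithm: str = "sha256") -> bool:
    """True if the file's digest matches the published value (case-insensitive)."""
    return file_digest(path, algorithm) == expected_hex.strip().lower()
```

For an intact download of the arm64 wheel, `verify_wheel("jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl", "f1dbfecb298cdd3ba6da3ad6dc9a2adb63d71741f8b8ece28c296b32d608b6c8")` should return True. pip's hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file) performs the same check during installation.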