Einshard
High-level array sharding API for JAX
Introduction
TODO: Add detailed introduction.
This project originated as part of the Mistral 7B v0.2 JAX project and has since evolved into an independent project.
This project is supported by Cloud TPUs from Google's TPU Research Cloud (TRC).
Installation
This library requires at least Python 3.12.
pip install einshard
Install JAX first, choosing the installation method that matches your hardware (e.g. CPU, GPU or TPU), and then install Einshard.
Usage
For testing purposes, we initialise the JAX CPU backend with 16 simulated devices. This setup must run before any other code (e.g. placed at the top of the script):
import os
n_devices = 16
# Set these before importing JAX so that they take effect.
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + f' --xla_force_host_platform_device_count={n_devices}'
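The setup above appends to `XLA_FLAGS` unconditionally. If it runs more than once in the same process, for example when re-executing a notebook cell, the device-count flag is added again each time. The sketch below avoids that by replacing any earlier occurrence of the flag. `set_host_device_count` is a hypothetical helper written for this illustration, not part of Einshard or JAX. As with the original snippet, call it before importing JAX.

```python
import os
import re


def set_host_device_count(n_devices: int) -> None:
    """Hypothetical helper: keep exactly one device-count flag in XLA_FLAGS.

    Any existing --xla_force_host_platform_device_count=... entry is removed
    before the new value is appended; other XLA flags are preserved.
    """
    flags = os.environ.get('XLA_FLAGS', '')
    flags = re.sub(r'\s*--xla_force_host_platform_device_count=\S*', '', flags)
    os.environ['XLA_FLAGS'] = f'{flags} --xla_force_host_platform_device_count={n_devices}'.strip()


# Same effect as the setup above, but safe to run repeatedly.
os.environ['JAX_PLATFORMS'] = 'cpu'
set_host_device_count(16)
```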
Code:
from einshard import einshard
import jax
import jax.numpy as jnp
a = jnp.zeros((4, 8))
a = einshard(a, 'a b -> * a* b2*')
jax.debug.visualize_array_sharding(a)
Output:
┌──────────┬──────────┬──────────┬──────────┐
│          │          │          │          │
│ CPU 0,8  │ CPU 1,9  │ CPU 2,10 │ CPU 3,11 │
│          │          │          │          │
│          │          │          │          │
├──────────┼──────────┼──────────┼──────────┤
│          │          │          │          │
│ CPU 4,12 │ CPU 5,13 │ CPU 6,14 │ CPU 7,15 │
│          │          │          │          │
│          │          │          │          │
└──────────┴──────────┴──────────┴──────────┘
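The output above splits the (4, 8) array into a 2 × 4 grid of 2 × 2 tiles, with every tile stored on two devices. One reading of `'a b -> * a* b2*'` is consistent with this. Under that reading, every `*` takes an equal, automatically computed factor k, and `2*` takes 2k. This gives k · k · 2k = 16, so k = 2: replication 2, axis `a` split 2 ways, axis `b` split 4 ways.

The plain-Python sketch below reconstructs that interpretation from the printed output only. It is not Einshard's implementation. `resolve_star_factors` and `tile_devices` are hypothetical helpers. Laying device ids out row-major over a (replica, a, b) mesh is an assumption that happens to reproduce the device pairs shown above.

```python
from itertools import product
from math import prod


def resolve_star_factors(n_devices: int, coefficients: list[int]) -> list[int]:
    """Hypothetical: find unit k with prod(c * k for c in coefficients) == n_devices.

    A bare '*' or 'a*' has coefficient 1; 'b2*' has coefficient 2.
    """
    n_slots = len(coefficients)
    base = prod(coefficients)
    k = 1
    # Increase k until k**n_slots * base reaches the device count.
    while k ** n_slots * base < n_devices:
        k += 1
    if k ** n_slots * base != n_devices:
        raise ValueError(f'{n_devices} devices cannot be split as {coefficients}')
    return [c * k for c in coefficients]


def tile_devices(mesh_shape: tuple[int, int, int]) -> dict[tuple[int, int], list[int]]:
    """Map each (row, col) tile to the device ids that hold a copy of it.

    Assumes device ids 0..n-1 are laid out row-major over a mesh with
    axes (replica, a, b).
    """
    n_rep, n_a, n_b = mesh_shape
    tiles: dict[tuple[int, int], list[int]] = {}
    for r, i, j in product(range(n_rep), range(n_a), range(n_b)):
        device_id = (r * n_a + i) * n_b + j
        tiles.setdefault((i, j), []).append(device_id)
    return tiles


factors = resolve_star_factors(16, [1, 1, 2])  # '*', 'a*', 'b2*'
print(factors)  # [2, 2, 4]
tiles = tile_devices((factors[0], factors[1], factors[2]))
print(tiles[(0, 0)], tiles[(1, 3)])  # [0, 8] [7, 15]
```

Each tile then covers 4 / 2 = 2 rows and 8 / 4 = 2 columns. Tile (0, 0) is held by devices 0 and 8, matching the top-left cell of the visualisation.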
Development
Create a virtual environment:
python3.12 -m venv venv
. venv/bin/activate
Install dependencies:
pip install -U pip
pip install -U wheel
pip install "jax[cpu]"
pip install -r requirements.txt
Run tests:
python tests/test_einshard.py
Build package:
pip install build
python -m build
Build docs:
cd docs
make html
cd _build/html
python -m http.server -b 127.0.0.1
Download files
File details
Details for the file einshard-0.2.0.tar.gz.
File metadata
- Download URL: einshard-0.2.0.tar.gz
- Upload date:
- Size: 10.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.9.19
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 07178559498da85da1a44a3ee70b9ef3a5acc1b993a35ab76c26f52cc2a83a8b |
| MD5 | 37e6e437afd1de8fb85e6054ba7e5374 |
| BLAKE2b-256 | b967c9522b2ac2ab1f2cc8b4c62f446ead137a20370172ff2fbe1b03d47ba7be |
File details
Details for the file einshard-0.2.0-py3-none-any.whl.
File metadata
- Download URL: einshard-0.2.0-py3-none-any.whl
- Upload date:
- Size: 10.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.9.19
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4a059d40a266e843826e0a4168613d98c62ea8aedee29e5ce9b9f9e1d0fa3fd3 |
| MD5 | 4f8e836a34b43be3ee1ab1da753140d7 |
| BLAKE2b-256 | e59c11d85aa5e5fe9cefda96c9bcdbaa7008c1c416ea47d5b3ce78a22410b65b |
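The published digests can be used to check a downloaded distribution file before installing it. Below is a minimal sketch using only the standard library's `hashlib`. `sha256_of` and `matches_expected` are illustrative names, not part of Einshard. The expected value is the SHA256 digest listed for `einshard-0.2.0.tar.gz`.

```python
import hashlib
from pathlib import Path

# SHA256 digest published for einshard-0.2.0.tar.gz.
EXPECTED_SDIST_SHA256 = '07178559498da85da1a44a3ee70b9ef3a5acc1b993a35ab76c26f52cc2a83a8b'


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with path.open('rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def matches_expected(path: Path, expected: str) -> bool:
    """Compare a file's SHA256 digest with a published hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()


# Usage, assuming the archive was downloaded into the current directory:
# print(matches_expected(Path('einshard-0.2.0.tar.gz'), EXPECTED_SDIST_SHA256))
```

pip can perform the same check at install time through its hash-checking mode, using a `--hash=sha256:<digest>` option on the requirement line of a requirements file.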