High-level array sharding API for JAX
Einshard
Introduction
TODO: Add detailed introduction.
This project originated as a part of the Mistral 7B v0.2 JAX project and has since evolved into an independent project.
This project is supported by Cloud TPUs from Google's TPU Research Cloud (TRC).
Installation
This library requires at least Python 3.12.
pip install einshard
JAX must be installed before Einshard. Choose the JAX installation method that matches your hardware.
Usage
For testing purposes, we initialise the JAX CPU backend with 16 devices. This must run before JAX is imported (e.g. place it at the top of the script):
n_devices = 16
import os
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + f' --xla_force_host_platform_device_count={n_devices}'
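The setup above appends the flag each time it runs, so re-running it in the same process (for example in a notebook) can leave several copies of the flag in XLA_FLAGS. The following is a minimal sketch of an idempotent variant. The helper name force_host_device_count is hypothetical and not part of Einshard or JAX; only the environment variables and the flag itself come from the snippet above.

```python
import os

def force_host_device_count(n_devices: int) -> str:
    """Hypothetical helper: select the CPU backend and request n_devices host devices."""
    flag = '--xla_force_host_platform_device_count'
    # Drop any earlier copy of the flag so repeated calls do not stack duplicates.
    kept = [f for f in os.environ.get('XLA_FLAGS', '').split()
            if not f.startswith(flag + '=')]
    kept.append(f'{flag}={n_devices}')
    os.environ['JAX_PLATFORMS'] = 'cpu'
    os.environ['XLA_FLAGS'] = ' '.join(kept)
    return os.environ['XLA_FLAGS']

# Like the original snippet, this has to run before jax is imported.
flags = force_host_device_count(16)
```

Other flags already present in XLA_FLAGS are preserved; only the device-count flag is replaced.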
Code:
from einshard import einshard
import jax
import jax.numpy as jnp
a = jnp.zeros((4, 8))
a = einshard(a, 'a b -> * a* b2*')
jax.debug.visualize_array_sharding(a)
Output:
┌──────────┬──────────┬──────────┬──────────┐
│ │ │ │ │
│ CPU 0,8 │ CPU 1,9 │ CPU 2,10 │ CPU 3,11 │
│ │ │ │ │
│ │ │ │ │
├──────────┼──────────┼──────────┼──────────┤
│ │ │ │ │
│ CPU 4,12 │ CPU 5,13 │ CPU 6,14 │ CPU 7,15 │
│ │ │ │ │
│ │ │ │ │
└──────────┴──────────┴──────────┴──────────┘
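For comparison, the layout shown above can probably be reproduced with JAX's own sharding API. This is a sketch based only on reading the output, not on Einshard's internals: it assumes the 16 devices form a 2 x 2 x 4 mesh whose first axis replicates the array, whose second axis splits `a`, and whose third axis splits `b`. The mesh axis names 'r', 'a' and 'b' are chosen here for illustration.

```python
import os

# Same 16-device CPU setup as above; it must run before jax is imported.
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + ' --xla_force_host_platform_device_count=16'

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed mesh: 2 replicas x 2 shards along 'a' x 4 shards along 'b'.
devices = np.array(jax.devices()).reshape(2, 2, 4)
mesh = Mesh(devices, axis_names=('r', 'a', 'b'))

# Shard dimension 0 over mesh axis 'a' and dimension 1 over 'b';
# the unused mesh axis 'r' replicates each shard.
sharding = NamedSharding(mesh, P('a', 'b'))
x = jax.device_put(jnp.zeros((4, 8)), sharding)

shard_shapes = {s.data.shape for s in x.addressable_shards}
jax.debug.visualize_array_sharding(x)
```

Each of the 16 devices holds a 2 x 2 block, and every block lives on two devices, which is consistent with the pairs such as CPU 0,8 in the output above. The Einshard expression states the same intent in one line without building the mesh by hand.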
Development
Create venv:
python3.12 -m venv venv
. venv/bin/activate
Install dependencies:
pip install -U pip
pip install -U wheel
pip install "jax[cpu]"
pip install -r requirements.txt
Run tests:
python tests/test_einshard.py
Build package:
pip install build
python -m build
Build docs:
cd docs
make html
cd _build/html
python -m http.server -b 127.0.0.1
File details
Details for the file einshard-0.2.0.tar.gz.
File metadata
- Download URL: einshard-0.2.0.tar.gz
- Upload date:
- Size: 10.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.9.19
File hashes
Algorithm | Hash digest
---|---
SHA256 | 07178559498da85da1a44a3ee70b9ef3a5acc1b993a35ab76c26f52cc2a83a8b
MD5 | 37e6e437afd1de8fb85e6054ba7e5374
BLAKE2b-256 | b967c9522b2ac2ab1f2cc8b4c62f446ead137a20370172ff2fbe1b03d47ba7be
File details
Details for the file einshard-0.2.0-py3-none-any.whl.
File metadata
- Download URL: einshard-0.2.0-py3-none-any.whl
- Upload date:
- Size: 10.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.9.19
File hashes
Algorithm | Hash digest
---|---
SHA256 | 4a059d40a266e843826e0a4168613d98c62ea8aedee29e5ce9b9f9e1d0fa3fd3
MD5 | 4f8e836a34b43be3ee1ab1da753140d7
BLAKE2b-256 | e59c11d85aa5e5fe9cefda96c9bcdbaa7008c1c416ea47d5b3ce78a22410b65b