Einshard

High-level array sharding API for JAX

Introduction

TODO: Add detailed introduction.

Einshard originated as part of the Mistral 7B v0.2 JAX project and has since become an independent project.

This project is supported by Cloud TPUs from Google's TPU Research Cloud (TRC).

Installation

This library requires at least Python 3.12.

pip install einshard

JAX must be installed before Einshard. Install it first, using the installation method that matches your platform.

Usage

For testing purposes, we initialise the JAX CPU backend with 16 devices. Run this before any other code (e.g. place it at the top of the script):

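# Set these environment variables before JAX initialises its backend
n_devices = 16  # number of simulated CPU devices
import os
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + f' --xla_force_host_platform_device_count={n_devices}'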
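
To confirm that the setting took effect, it can help to check the device count right after importing JAX. The following is a minimal sketch that repeats the setup above so it runs on its own; the names `device_count` and `platform` are introduced here for illustration only.

```python
import os

n_devices = 16  # same value as in the snippet above
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + f' --xla_force_host_platform_device_count={n_devices}'

import jax  # imported only after the environment variables are set

# Number of devices JAX sees, and the platform of the first one
device_count = jax.device_count()
platform = jax.devices()[0].platform
print(device_count, platform)
```

If the reported count is not 16, the most likely cause is that JAX had already initialised its backend before the environment variables were set.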

Code:

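from einshard import einshard
import jax
import jax.numpy as jnp

a = jnp.zeros((4, 8))
# Shard the 4x8 array across the devices as described by the einshard expression
a = einshard(a, 'a b -> * a* b2*')
# Draw a diagram showing which devices hold which part of the array
jax.debug.visualize_array_sharding(a)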

Output:

┌──────────┬──────────┬──────────┬──────────┐
│          │          │          │          │
│ CPU 0,8  │ CPU 1,9  │ CPU 2,10 │ CPU 3,11 │
│          │          │          │          │
│          │          │          │          │
├──────────┼──────────┼──────────┼──────────┤
│          │          │          │          │
│ CPU 4,12 │ CPU 5,13 │ CPU 6,14 │ CPU 7,15 │
│          │          │          │          │
│          │          │          │          │
└──────────┴──────────┴──────────┴──────────┘
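
For comparison, the layout in the diagram above (2 row blocks, 4 column blocks, each block held by 2 devices) can be written out with JAX's built-in sharding API. This is a sketch inferred from the output only, not a description of how Einshard works internally. The mesh axis names `replica`, `row` and `col` and the variables `shard_shape` and `n_shards` are illustrative choices.

```python
import os

os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + ' --xla_force_host_platform_device_count=16'

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange the 16 devices as a (replica, row, col) = (2, 2, 4) grid, so that
# devices i and i + 8 end up holding the same block
device_grid = np.array(jax.devices()).reshape(2, 2, 4)
mesh = Mesh(device_grid, axis_names=('replica', 'row', 'col'))

# Split array axis 0 over 'row' and axis 1 over 'col'; the 'replica' mesh
# axis is not mentioned, so the array is replicated along it
sharding = NamedSharding(mesh, PartitionSpec('row', 'col'))

a = jax.device_put(jnp.zeros((4, 8)), sharding)

shard_shape = a.sharding.shard_shape(a.shape)  # shape of the block on each device
n_shards = len(a.addressable_shards)           # one entry per device
print(shard_shape, n_shards)
```

Calling `jax.debug.visualize_array_sharding(a)` on this array should draw the same 2 x 4 grid. The Einshard expression states the layout in one string, without building the mesh by hand.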

Development

Create venv:

python3.12 -m venv venv
. venv/bin/activate

Install dependencies:

pip install -U pip
pip install -U wheel
pip install "jax[cpu]"
pip install -r requirements.txt

Run tests:

python tests/test_einshard.py

Build package:

pip install build
python -m build

Build docs:

cd docs
make html
cd _build/html
python -m http.server -b 127.0.0.1
