High-level array sharding API for JAX

Einshard

Introduction

TODO: Add detailed introduction.

This project originated as part of the Mistral 7B v0.2 JAX project and has since evolved into an independent project.

This project is supported by Cloud TPUs from Google's TPU Research Cloud (TRC).

Installation

This library requires at least Python 3.12.

pip install einshard

Usage

For testing purposes, we initialise the JAX CPU backend with 16 simulated devices. This must run before JAX is imported (e.g. placed at the top of the script):

import os

n_devices = 16
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + f' --xla_force_host_platform_device_count={n_devices}'
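To confirm the flags took effect, you can check the device count before running any sharding code. A minimal sketch; it repeats the setup so it runs on its own in a fresh Python process:

```python
import os

# Must be set before JAX is imported, otherwise the CPU backend is already configured.
n_devices = 16
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + f' --xla_force_host_platform_device_count={n_devices}'

import jax

# Each simulated device is reported as a separate CPU device.
print(jax.device_count())  # 16
print([d.platform for d in jax.devices()][:2])  # ['cpu', 'cpu']
```

If the count is 1, JAX was most likely imported before the environment variables were set.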

Code:

import einshard 
import jax
import jax.numpy as jnp

a = jnp.zeros((4, 8))
a = einshard.shard(a, 'a b -> * a* b2*')
jax.debug.visualize_array_sharding(a)

Output:

┌──────────┬──────────┬──────────┬──────────┐
│          │          │          │          │
│ CPU 0,8  │ CPU 1,9  │ CPU 2,10 │ CPU 3,11 │
│          │          │          │          │
│          │          │          │          │
├──────────┼──────────┼──────────┼──────────┤
│          │          │          │          │
│ CPU 4,12 │ CPU 5,13 │ CPU 6,14 │ CPU 7,15 │
│          │          │          │          │
│          │          │          │          │
└──────────┴──────────┴──────────┴──────────┘
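The output shows 2 row blocks × 4 column blocks, each 2×2 block held by two devices (e.g. CPU 0 and CPU 8). As a sketch for comparison, the same layout can be built with JAX's own `Mesh`, `NamedSharding` and `PartitionSpec`. This is not necessarily how einshard constructs its sharding internally, and the mesh axis names `replica`, `a` and `b` are illustrative assumptions:

```python
import os

# Same 16-device CPU setup as above; must run before JAX is imported.
n_devices = 16
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + f' --xla_force_host_platform_device_count={n_devices}'

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange the 16 devices as (replica=2, a=2, b=4), so device id = r*8 + a*4 + b.
devices = np.array(jax.devices()).reshape(2, 2, 4)
mesh = Mesh(devices, axis_names=('replica', 'a', 'b'))

# Shard dim 0 over mesh axis 'a' and dim 1 over 'b'. The 'replica' axis is not
# named in the spec, so every shard is copied onto both devices along it.
sharding = NamedSharding(mesh, P('a', 'b'))
a = jax.device_put(jnp.zeros((4, 8)), sharding)

# Group device ids by the (row block, column block) of `a` they hold; blocks are 2x2.
placement = {}
for shard in a.addressable_shards:
    rows, cols = shard.index
    key = ((rows.start or 0) // 2, (cols.start or 0) // 2)
    placement.setdefault(key, []).append(shard.device.id)
placement = {k: sorted(v) for k, v in sorted(placement.items())}
print(placement)
```

This prints `{(0, 0): [0, 8], (0, 1): [1, 9], (0, 2): [2, 10], (0, 3): [3, 11], (1, 0): [4, 12], (1, 1): [5, 13], (1, 2): [6, 14], (1, 3): [7, 15]}`, the same placement as the grid above. Calling `jax.debug.visualize_array_sharding(a)` on this array gives a visual check. The einshard expression expresses the same intent in a single string instead of an explicit mesh.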

Development

python3.12 -m venv venv
. venv/bin/activate
pip install -U pip
pip install -U wheel
pip install "jax[cpu]"
pip install -r requirements.txt

Run test:

python tests/test_einshard.py

Build package:

pip install build
python -m build

Build docs:

cd docs
make html
cd _build/html
python -m http.server -b 127.0.0.1

Download files

Source distribution: einshard-0.1.0.tar.gz (10.7 kB)

Built distribution: einshard-0.1.0-py3-none-any.whl (10.3 kB, Python 3)