Alpa automatically parallelizes large tensor computation graphs and runs them on a distributed cluster.
Alpa
Alpa is a system for training large-scale neural networks. Scaling neural networks to hundreds of billions of parameters has enabled dramatic breakthroughs such as GPT-3, but training these large-scale neural networks requires complicated distributed training techniques. Alpa aims to automate large-scale distributed training with just a few lines of code.
The key features of Alpa include:
💻 Automatic Parallelization. Alpa automatically parallelizes users' single-device code on distributed clusters with data, operator, and pipeline parallelism.
🚀 Excellent Performance. Alpa achieves linear scaling when training models with billions of parameters on distributed clusters.
✨ Tight Integration with Machine Learning Ecosystem. Alpa is backed by open-source, high-performance, and production-ready libraries such as Jax, XLA, and Ray.
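To make the first feature concrete, here is a hedged pure-Python sketch of the simplest of the three strategies, data parallelism. This is not Alpa code; the linear model, loss, and helper names below are made up for illustration. It shows the core idea Alpa automates: split a batch across workers, compute a gradient per shard, and average the shard gradients (an all-reduce in a real system) to recover the full-batch gradient.

```python
# Illustrative sketch of data parallelism (NOT Alpa's implementation).
# Model: y = w * x with mean-squared-error loss. The batch gradient is a
# mean over per-example gradients, so it can be computed shard-by-shard.

def grad_mse(w, xs, ys):
    # d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def data_parallel_grad(w, xs, ys, num_workers):
    # Split the batch into equal shards, one per "worker".
    shard = len(xs) // num_workers
    grads = [
        grad_mse(w, xs[i * shard:(i + 1) * shard],
                    ys[i * shard:(i + 1) * shard])
        for i in range(num_workers)
    ]
    # Averaging equal-sized shard gradients reproduces the full-batch
    # gradient exactly for a mean-based loss.
    return sum(grads) / num_workers

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5
full = grad_mse(w, xs, ys)
parallel = data_parallel_grad(w, xs, ys, num_workers=2)
assert abs(full - parallel) < 1e-9
```

Operator and pipeline parallelism instead partition the model itself (across tensor dimensions and across layer groups, respectively); Alpa's contribution is choosing and combining all three automatically.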
Quick Start
Use Alpa's decorator `@parallelize` to scale your single-device training code to distributed clusters.
```python
import alpa
import jax.numpy as jnp
from jax import grad

# Parallelize the training step in Jax by simply using a decorator
@alpa.parallelize
def train_step(model_state, batch):
    def loss_func(params):
        out = model_state.forward(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)
    grads = grad(loss_func)(model_state.params)
    new_model_state = model_state.apply_gradient(grads)
    return new_model_state

# The training loop now automatically runs on your designated cluster
model_state = create_train_state()
for batch in data_loader:
    model_state = train_step(model_state, batch)
```
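To clarify what the decorated function computes, here is a hedged single-device analogue of the training step in plain Python. The `ModelState` class, its methods, and the hand-written gradient below are hypothetical stand-ins for the Jax model state and `grad`; none of these names come from Alpa's API.

```python
# Plain-Python analogue of train_step (illustration only; the real
# example uses Jax autodiff and runs Alpa-parallelized on a cluster).

class ModelState:
    def __init__(self, w):
        self.params = w  # a single scalar weight

    def forward(self, params, xs):
        return [params * x for x in xs]

    def apply_gradient(self, grad, lr=0.01):
        # One SGD step; returns a new state, mirroring the functional
        # style of the Jax example above.
        return ModelState(self.params - lr * grad)

def train_step(model_state, batch):
    xs, ys = batch["x"], batch["y"]
    preds = model_state.forward(model_state.params, xs)
    # Gradient of mean((w*x - y)^2) w.r.t. w, written out by hand:
    # mean(2 * (w*x - y) * x)
    g = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
    return model_state.apply_gradient(g)

batch = {"x": [1.0, 2.0, 3.0], "y": [2.0, 4.0, 6.0]}
state = ModelState(0.0)
for _ in range(200):
    state = train_step(state, batch)
# The weight converges toward 2.0, the slope of the data.
```

With `@alpa.parallelize`, the same per-step structure runs unchanged, but Alpa decides how to shard the batch, the parameters, and the computation across the cluster.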
Check out the Alpa Documentation site for installation instructions, tutorials, examples, and more.
More Information
- Alpa paper (OSDI'22)
- Google AI Blog
- Alpa talk slides
Getting Involved
- Please read the contributor guide if you are interested in contributing to Alpa.
- Please connect with Alpa contributors via the Alpa Slack.
License
Alpa is licensed under the Apache-2.0 license.
Download files
Download the file for your platform.
Built Distributions
Hashes for alpa-0.1.4-cp39-cp39-manylinux2014_x86_64.whl

Algorithm | Hash digest
---|---
SHA256 | 66536e57eba76aa5cff2eaf4b4f0517f865c48dded5e8d143caa830f765bab4a
MD5 | 2f32592d05378738a27d3ff9119a5998
BLAKE2b-256 | e3bdaeb6042abd5bc5bf6e803225eedb1fd296a4b7e99f0a7039d201925614de
Hashes for alpa-0.1.4-cp38-cp38-manylinux2014_x86_64.whl

Algorithm | Hash digest
---|---
SHA256 | 5c75902b1b54a12dec76f1b68dc100f336c347f51ebd1381875ced1d6dc4a4b7
MD5 | 8704edd13fe5875170043a8ccc885616
BLAKE2b-256 | 5ea44586e10fc7f31ed2432e9b3e2719b13550e78b6ca666dfd0cf401a101514
Hashes for alpa-0.1.4-cp37-cp37m-manylinux2014_x86_64.whl

Algorithm | Hash digest
---|---
SHA256 | 4b35fd90a99e07f13926dacb8a4dbc6df1a0c3d51bf7d5c90ab4048f266bd4c3
MD5 | c4da95e68401c416e03ff5c0a85d591e
BLAKE2b-256 | 7233f9396b56efe9999f2439a9018a1046b9e43d6f4ef36c60b490174898aca9