
Slapo: A Schedule Language for Progressive Optimization.


Slapo: A Schedule Language for Large Model Training


Slapo is a schedule language for progressive optimization of large deep learning model training.

Large deep learning models achieve leading accuracy on a wide range of NLP and CV tasks, but training them efficiently without sacrificing usability is difficult. Slapo addresses this tension through separation of concerns: it decouples model execution from model definition, so developers can apply a set of schedule primitives to a PyTorch model for common training optimizations without directly changing the model itself.

Slapo highlights the following features:

:rocket: Progressive optimization. Slapo incorporates a "trace by need" approach that traces only the desired modules into static graphs for aggressive compiler-based optimizations.

:building_construction: Structure-preserving scheduling. Slapo preserves the module hierarchy when constructing the schedule, so developers can easily locate a module and apply primitives to it. This also helps users debug performance and convergence issues.

:gear: Auto-tuning. Slapo provides a programming interface that lets developers specify a set of tunable knobs to form an efficient tuning space, which the Slapo auto-tuner then explores to find the best configuration.

Getting Started

Installation

There are two ways to install Slapo:

  1. Install from PyPI:
pip3 install slapo
  2. Install from source:
git clone https://github.com/awslabs/slapo.git slapo
cd slapo
pip3 install -e ".[dev]"

In addition, you can optionally install HuggingFace Transformers (>= v4.25.1) to retrieve models. Also, Slapo currently supports the following frameworks, so you can run the scheduled models on these frameworks if needed.

Usage

Please see the examples folder for more details. Documentation will be released soon.

import slapo

# Load a PyTorch model from HuggingFace Hub, TorchVision, etc.
from transformers import BertLMHeadModel, AutoConfig
config = AutoConfig.from_pretrained("bert-large-uncased")
bert = BertLMHeadModel(config)

# Create a default schedule
sch = slapo.create_schedule(bert)

# Apply primitives to optimize the model, e.g., gradient checkpointing on the first encoder layer
# Please refer to examples/bert/schedule.py for more details
sch["bert.encoder.layer.0"].checkpoint()

# Build an optimized model
opt_model = slapo.build(sch)

# Run the optimized model
inputs = ...
outputs = opt_model(inputs)
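Indexing a schedule with a string such as "bert.encoder.layer.0" addresses a submodule by its dotted path in the preserved module hierarchy. The stdlib-only sketch below illustrates that lookup with a hypothetical Node class standing in for PyTorch modules; Node and resolve are names made up for this illustration, not Slapo's implementation. On real PyTorch modules, torch.nn.Module.get_submodule performs the same kind of dotted-path lookup.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class Node:
    """Hypothetical stand-in for an nn.Module that only tracks named children."""
    name: str
    children: Dict[str, "Node"] = field(default_factory=dict)

    def add(self, child: "Node") -> "Node":
        self.children[child.name] = child
        return child


def resolve(root: Node, path: str) -> Node:
    """Follow a dotted path such as "bert.encoder.layer.0" one segment at a time."""
    node = root
    for part in path.split("."):
        if part not in node.children:
            raise KeyError(f"{path!r}: {node.name!r} has no child {part!r}")
        node = node.children[part]
    return node


# A tiny hierarchy mirroring BertLMHeadModel's bert.encoder.layer.<i> naming.
root = Node("model")
layer_list = root.add(Node("bert")).add(Node("encoder")).add(Node("layer"))
for i in range(2):
    layer_list.add(Node(str(i)))

target = resolve(root, "bert.encoder.layer.0")
print(target.name)  # 0
```

Because the hierarchy is kept intact, the same path a developer sees in the model definition is the one used to address it in the schedule.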

Supported Primitives

To minimize the risk introduced by tracers and compilers, we use progressive optimization to apply primitives gradually to parts of the model. We classify the primitives into two categories. Primitives of the first type do not require tracing and can be applied directly to modules and parameters; primitives of the second type require a static graph, so the .trace() primitive must be applied first.
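The ordering rule between the two categories can be modeled with a small toy class. This is a hypothetical sketch, not Slapo's implementation: ToySchedule and apply are invented names, and the subgraph variants of replace and checkpoint (which are static-graph primitives) are omitted for simplicity. It only shows that dynamic primitives work on an untraced module while static ones are rejected until that module has been traced.

```python
class ToySchedule:
    """Hypothetical toy model of the two primitive categories; not Slapo's code."""

    # Primitives usable on an untraced (dynamic) module.
    DYNAMIC = {"replace", "shard", "sync", "checkpoint", "fork_rng", "annotate"}
    # Primitives that operate on a static graph and therefore need trace() first.
    STATIC = {"find", "fuse", "decompose", "cut_pipeline_stage"}

    def __init__(self, path):
        self.path = path
        self.traced = False
        self.applied = []

    def trace(self):
        # Only this submodule becomes a static graph ("trace by need").
        self.traced = True
        self.applied.append("trace")
        return self

    def apply(self, primitive):
        if primitive not in self.DYNAMIC | self.STATIC:
            raise ValueError(f"unknown primitive {primitive!r}")
        if primitive in self.STATIC and not self.traced:
            raise RuntimeError(
                f"{primitive!r} needs a static graph; call trace() on {self.path!r} first"
            )
        self.applied.append(primitive)
        return self


s = ToySchedule("bert.encoder.layer.0")
s.apply("checkpoint")  # dynamic primitive: no tracing needed
try:
    s.apply("fuse")  # static primitive before tracing is rejected
except RuntimeError as err:
    print(err)
s.trace().apply("fuse")
print(s.applied)  # ['checkpoint', 'trace', 'fuse']
```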

We provide the following primitives for dynamic graph optimizations:

| Feature | Primitive |
| --- | --- |
| Module replacement | s[op].replace(new_module) |
| Tensor parallelism | s[op].shard(param_name, axis) |
| Synchronization | s[op].sync(mode="fwd_pre/fwd_post/bwd_post", sync_op_or_fn, **kwargs) |
| Checkpointing | s[op].checkpoint() |
| Fork random number generator | s[op].fork_rng() |
| Annotate parameters | s[op].annotate(param_name, key, value) |
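To illustrate what sharding a parameter along an axis means for tensor parallelism, here is a stdlib-only sketch using plain lists. The helpers shard and matvec are hypothetical, and the weight layout [out_features][in_features] follows torch.nn.Linear; whether Slapo's axis argument follows exactly this convention is an assumption. The sketch shows why sharded results need a synchronization step afterwards: splitting output rows requires concatenating the partial outputs (an all-gather), while splitting input columns requires summing them (an all-reduce).

```python
def shard(weight, axis, world_size):
    """Split a 2-D weight (list of rows) into `world_size` equal shards along `axis`."""
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1 for a 2-D weight")
    dim = len(weight) if axis == 0 else len(weight[0])
    if dim % world_size:
        raise ValueError(f"dimension {dim} is not divisible by world_size={world_size}")
    step = dim // world_size
    shards = []
    for rank in range(world_size):
        lo, hi = rank * step, (rank + 1) * step
        if axis == 0:
            shards.append([row[:] for row in weight[lo:hi]])  # a slice of output rows
        else:
            shards.append([row[lo:hi] for row in weight])  # a slice of input columns
    return shards


def matvec(weight, x):
    """y = W x for a weight laid out as [out_features][in_features], like nn.Linear."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


W = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
x = [1, 1, 1, 1]
full = matvec(W, x)

# axis=0: each rank owns some output features; concatenation (all-gather) recovers y.
row_parts = [matvec(w, x) for w in shard(W, axis=0, world_size=2)]
gathered = [y for part in row_parts for y in part]

# axis=1: each rank owns some input features plus the matching slice of x;
# summing the partial results (all-reduce) recovers y.
col_shards = shard(W, axis=1, world_size=2)
step = len(x) // len(col_shards)
partials = [matvec(w, x[r * step:(r + 1) * step]) for r, w in enumerate(col_shards)]
reduced = [sum(vals) for vals in zip(*partials)]

print(full, gathered, reduced)  # [10, 26] [10, 26] [10, 26]
```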

And the following primitives for static graph optimizations:

| Feature | Primitive |
| --- | --- |
| Module Tracing | s.trace(leaves, flatten) |
| Pattern matching | s.find(regex_or_pattern_fn) |
| Operator fusion | s[op].fuse(compiler, subgraph) |
| Layer decomposition | s[op].decompose() |
| Partial module replacement | s[op].replace(new_module, subgraph) |
| Partial gradient checkpointing | s[op].checkpoint(subgraph) |
| Pipeline parallelism | s[op].cut_pipeline_stage() |
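Static-graph primitives such as fuse and partial replace operate on subgraphs that are first located with pattern matching. The sketch below is a hypothetical analogue of s.find, not Slapo's implementation: it simplifies a traced graph to a chain of (node_name, op_type) pairs, matches a string pattern as a regex against node names, and, in place of a pattern function, matches a list of op types as a run of consecutive nodes. Real traced graphs are DAGs, so this chain model is a deliberate simplification.

```python
import re
from typing import List, Sequence, Tuple, Union

# A traced graph simplified to a chain of (node_name, op_type) in execution order.
Node = Tuple[str, str]


def find(graph: Sequence[Node], pattern: Union[str, Sequence[str]]) -> List[List[Node]]:
    """Return matches of `pattern` in `graph` (hypothetical helper).

    - str: a regex searched against node names; each match is a single node.
    - sequence of op types: every run of consecutive nodes with exactly those ops.
    """
    if isinstance(pattern, str):
        regex = re.compile(pattern)
        return [[node] for node in graph if regex.search(node[0])]
    n = len(pattern)
    return [
        list(graph[i:i + n])
        for i in range(len(graph) - n + 1)
        if [op for _, op in graph[i:i + n]] == list(pattern)
    ]


# A feed-forward block: the bias add followed by GELU is a typical fusion candidate.
graph = [
    ("hidden_states", "placeholder"),
    ("intermediate_dense", "matmul"),
    ("bias_add", "add"),
    ("gelu", "gelu"),
    ("output_dense", "matmul"),
    ("residual_add", "add"),
]

print(find(graph, r"dense$"))        # both matmul nodes
print(find(graph, ["add", "gelu"]))  # [[('bias_add', 'add'), ('gelu', 'gelu')]]
```

A matched subgraph like the bias add followed by GELU is the kind of span one would then hand to a fusion primitive.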

You can list all supported primitives with the following API:

import slapo
print(slapo.list_primitives())

You could also check the description of each primitive on the fly:

import slapo
help(slapo.list_primitives(name_only=False)["shard"])

Auto-Tuning

We also provide a lightweight interface for auto-tuning, so developers can (1) construct a polyhedral search space using our APIs and (2) use the Slapo auto-tuner to automatically search for the best training configuration.

cd benchmark
# Single device
# The following script will trigger the tuning jobs for all the models
python3 tune_single_device.py
# Single node
python3 tune_single_node.py
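The idea of a constrained tuning space can be sketched without Slapo itself. Everything below is hypothetical: the knobs batch_size and ckpt_ratio, the memory budget, and both cost formulas are synthetic stand-ins rather than Slapo's tuner API or measured results. The point is structural: the range of a later knob depends on an earlier choice, and the tuner keeps only configurations that satisfy the constraint before picking the fastest one.

```python
# Hypothetical knobs, constraint, and cost model; all numbers are synthetic.
MEM_BUDGET = 15.0  # arbitrary memory units


def search_space():
    """Yield configurations; the ckpt_ratio range depends on the chosen batch_size."""
    for batch_size in (8, 16, 32, 64):
        min_ratio = (batch_size - 8) / 64  # larger batches require more checkpointing
        for ckpt_ratio in (0.0, 0.25, 0.5, 0.75, 1.0):
            if ckpt_ratio >= min_ratio:
                yield {"batch_size": batch_size, "ckpt_ratio": ckpt_ratio}


def memory(cfg):
    # Checkpointing discards activations, so memory shrinks as ckpt_ratio grows.
    return cfg["batch_size"] * (1 - 0.75 * cfg["ckpt_ratio"])


def throughput(cfg):
    # Recomputing checkpointed activations costs time, so throughput shrinks too.
    return cfg["batch_size"] * (1 - 0.3 * cfg["ckpt_ratio"])


def tune():
    """Exhaustively evaluate feasible configs and return the fastest one."""
    feasible = [cfg for cfg in search_space() if memory(cfg) <= MEM_BUDGET]
    return max(feasible, key=throughput)


best = tune()
print(best)  # {'batch_size': 32, 'ckpt_ratio': 0.75}
```

Exhaustive enumeration is used here only because the toy space has 13 points; a real tuner would need a smarter search once the space grows.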

Benchmarking

We provide scripts to reproduce our results on a single AWS EC2 p3.16xlarge node with 8 V100 GPUs.

cd benchmark
# Download datasets
bash download_benchmark_dataset.sh
# Run benchmarking
# Megatron-LM and DeepSpeed are required to run the experiments
bash run_all_single_node.sh config/single_node_v100.cfg

Publication

If you use Slapo in your project, please contact the authors for citation information.

License

Slapo is released under the Apache 2.0 license.


Download files


Source Distributions

No source distribution files are available for this release.

Built Distribution

slapo-0.0.3-py3-none-any.whl (145.1 kB)

Uploaded Python 3

File details

Details for the file slapo-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: slapo-0.0.3-py3-none-any.whl
  • Size: 145.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.0

File hashes

Hashes for slapo-0.0.3-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 47fa9a0289bec7751efd446cd1691653e309debdf55cafd9436e6b5ecd77ae77 |
| MD5 | eebfdd8b1f32a40bb69dc53da27182b6 |
| BLAKE2b-256 | 035ff27936475e5a2ee01506904a8e57b477ab0cdf5f0a1995311ce0a9001f55 |

