Decoupled Torch Network-Aware Training on Interlinked Online Nodes (DeToNATION)
This code currently implements the method described in the paper FlexDeMo: Decoupled Momentum Optimization for Fully and Hybrid Sharded Training. Code to run all experiments from the paper is in the benchmarks folder.
Installation
Installation from PyPI:
pip install detonation
Installation from source:
git clone https://github.com/schneiderkamplab/DeToNATION
cd DeToNATION
pip install .
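To confirm that either installation route worked, the installed version can be queried from Python. This is a minimal sketch using only the standard library; the helper name installed_version is hypothetical and not part of DeToNATION.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str = "detonation"):
    """Return the installed version string of `package`, or None if it is missing.

    `installed_version` is a hypothetical helper for this README,
    not a function provided by DeToNATION.
    """
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version()
    print(f"detonation {found}" if found else "detonation is not installed")
```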
Example
There is a full example for language model training using FlexDeMo in the examples folder. Please refer to the documentation:
examples/t5/README.md
This example demonstrates the use of the prepare_detonation function for obtaining a distributed model and optimizer.
Benchmarks
There is a full benchmarking example for language model training using FlexDeMo in the benchmarks folder. Please refer to the documentation:
benchmarks/t5/README.md
This benchmarking example demonstrates the use of the prepare_detonation function for obtaining a distributed model and optimizer, and uses aim and mltiming to track model parameters and performance.
Usage
Using DeToNATION directly, without prepare_detonation, requires three steps, shown below for the FlexDeMo optimizer, i.e., DeToNATION with node-based hybrid sharding using DeMo replication.
First, wrap your model with FSDP using the hybrid sharding strategy:

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import ShardingStrategy

    # Shard parameters within each node and replicate across nodes.
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    )
Then, import and instantiate the FlexDeMo optimizer:

    from detonation import DeMo

    optim = DeMo(
        compression_topk=16,
        compression_chunk=128,
        sharding_parallel_group=model.process_group,
        replication_parallel_group=model._inter_node_pg,
    )
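To make the two process groups concrete, the following sketch enumerates which ranks would share a sharding group (within a node) and which would share a replication group (across nodes) under hybrid sharding. It is an illustration only: the function name hybrid_shard_groups is hypothetical, it does not call torch or DeToNATION, and it assumes that ranks are numbered contiguously per node and that every node has the same number of GPUs.

```python
def hybrid_shard_groups(world_size: int, gpus_per_node: int):
    """Return (sharding_groups, replication_groups) as lists of rank lists.

    Hypothetical helper for illustration. Assumes ranks are assigned
    contiguously per node, e.g. node 0 holds ranks 0..gpus_per_node-1.
    """
    if gpus_per_node <= 0 or world_size % gpus_per_node != 0:
        raise ValueError("world_size must be a positive multiple of gpus_per_node")
    num_nodes = world_size // gpus_per_node
    # One sharding group per node: all ranks located on that node.
    sharding = [
        list(range(node * gpus_per_node, (node + 1) * gpus_per_node))
        for node in range(num_nodes)
    ]
    # One replication group per local rank: the same local rank on every node.
    replication = [
        list(range(local_rank, world_size, gpus_per_node))
        for local_rank in range(gpus_per_node)
    ]
    return sharding, replication


if __name__ == "__main__":
    shard, repl = hybrid_shard_groups(world_size=8, gpus_per_node=4)
    print("sharding groups:   ", shard)
    print("replication groups:", repl)
```

With 2 nodes of 4 GPUs each, every rank belongs to one sharding group of 4 ranks and one replication group of 2 ranks.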
Third and last, wrap the forward and backward pass in a no_sync context manager to avoid automatic full gradient synchronization:

    with model.no_sync():  # Disable gradient synchronizations across FSDP instances.
        loss = model(input_ids=batch["input_ids"], labels=batch["labels"])["loss"]
        loss.backward()
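Put together, one training step might look like the sketch below. The function name train_step is hypothetical, and the sketch assumes that the optimizer follows the standard torch.optim.Optimizer interface (zero_grad and step) and that the model returns a mapping with a "loss" entry, as in the snippet above. Calling step after leaving the no_sync block is an assumption about the intended order, not something this README states.

```python
def train_step(model, optim, batch):
    """Run one forward/backward/update step and return the loss.

    Hypothetical helper: `model` is expected to be the FSDP-wrapped model
    and `optim` the optimizer created above. Both are only assumed to
    expose no_sync(), zero_grad(), and step().
    """
    optim.zero_grad()
    # Skip FSDP's automatic gradient synchronization for this pass.
    with model.no_sync():
        loss = model(input_ids=batch["input_ids"], labels=batch["labels"])["loss"]
        loss.backward()
    # Assumed: the optimizer step applies the update after the backward pass.
    optim.step()
    return loss
```

A typical loop would call train_step(model, optim, batch) once per batch from the data loader.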
File details
Details for the file detonation-0.5.2.tar.gz.
File metadata
- Download URL: detonation-0.5.2.tar.gz
- Upload date:
- Size: 19.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e45991e830fc2b0b13ec0fb82793abd8a263a1609002861daa7b4ab9984432af |
| MD5 | c56f9ce33bcfddd36850dbb842c9c0c9 |
| BLAKE2b-256 | cd302a23d945d78083166efe8e690e6f89e5ba12d716e0aa5859511d87392e88 |
File details
Details for the file detonation-0.5.2-py3-none-any.whl.
File metadata
- Download URL: detonation-0.5.2-py3-none-any.whl
- Upload date:
- Size: 34.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 321584a22ca65df69164b32cfb582e1e48044fc39cef428016c26a39b8ba4329 |
| MD5 | 7149ac67100768ec6d25c1d78c8e69f5 |
| BLAKE2b-256 | 2518c7f824dd3aef36c71ce36801c0bd32b875a5de0245ba8c9124f0c659d920 |