Automatically initialize distributed PyTorch environments
Project description
torchrunx 🔥
By Apoorv Khandelwal and Peter Curtin
Automatically distribute PyTorch functions onto multiple machines or GPUs
Installation

```bash
pip install torchrunx
```

Requires: Linux (with shared filesystem & SSH access if using multiple machines)
Demo
Here's a simple example where we "train" a model on two nodes (with 2 GPUs each).
Training code
```python
import os
import torch

def train():
    # torchrunx sets the usual distributed environment variables (RANK, LOCAL_RANK, ...)
    # and initializes the process group for each worker, so no dist.init_process_group here.
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    model = torch.nn.Linear(10, 10).to(local_rank)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(5, 10))
    labels = torch.randn(5, 10).to(local_rank)
    torch.nn.functional.mse_loss(outputs, labels).backward()
    optimizer.step()

    if rank == 0:
        return model
```
You could also use transformers.Trainer
(or similar) to automatically handle all the multi-GPU / DDP code above.
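For illustration, here is a rough sketch of such a Trainer-based training function. It is not taken from the torchrunx docs: the toy model and in-memory dataset are placeholders, and it assumes transformers is installed (Trainer picks up the distributed environment that torchrunx sets up).

```python
import torch
from transformers import Trainer, TrainingArguments

class ToyRegressor(torch.nn.Module):
    # Any nn.Module whose forward returns a dict with a "loss" key works with Trainer.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x, labels=None):
        outputs = self.linear(x)
        loss = torch.nn.functional.mse_loss(outputs, labels) if labels is not None else None
        return {"loss": loss, "logits": outputs}

def train():
    # A tiny in-memory dataset: each item is a dict of tensors, as Trainer expects.
    dataset = [{"x": torch.randn(10), "labels": torch.randn(10)} for _ in range(64)]

    trainer = Trainer(
        model=ToyRegressor(),
        args=TrainingArguments(output_dir="output", per_device_train_batch_size=8, max_steps=10),
        train_dataset=dataset,
    )
    trainer.train()  # Trainer handles device placement and DDP wrapping across the workers
    return trainer.model
```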
```python
# Launcher script: run this with plain `python`, not `torchrun`.
import torch
import torchrunx as trx

if __name__ == "__main__":
    result = trx.launch(
        func=train,  # the function defined above
        hostnames=["localhost", "other_node"],
        workers_per_host=2  # number of GPUs per node
    )

    trained_model = result.rank(0)  # return value of the global rank-0 worker
    torch.save(trained_model.state_dict(), "model.pth")
```
Full API
Advanced Usage
Why should I use this?
Whether you have 1 GPU, 8 GPUs, or 8 machines:
Features
- Our `launch()` utility is super Pythonic
- Return objects from your workers (see the sketch after this list)
- Run `python script.py` instead of `torchrun script.py`
- Launch multi-node functions, even from Python Notebooks
- Fine-grained control over logging, environment variables, exception handling, etc.
- Automatic integration with SLURM
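As a quick sketch of the "return objects" point, using only the `launch()` arguments shown in the demo above (the worker function here is made up for this example, and we assume `result.rank()` accepts any global rank, as `result.rank(0)` does in the demo):

```python
import os
import torch
import torchrunx as trx

def get_device_name():
    # Each worker runs this and returns an ordinary Python object to the launcher.
    local_rank = int(os.environ["LOCAL_RANK"])
    return torch.cuda.get_device_name(local_rank)

if __name__ == "__main__":
    result = trx.launch(func=get_device_name, hostnames=["localhost"], workers_per_host=2)
    print(result.rank(0))  # GPU name reported by the rank-0 worker
    print(result.rank(1))
```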
Robustness
- If you want to run a complex, modular workflow in one script (see the sketch after this list):
  - don't parallelize your entire script: just the functions you want!
  - no worries about memory leaks or OS failures
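A sketch of what that modular structure could look like; the preprocessing and evaluation functions are placeholders, not part of torchrunx:

```python
import torchrunx as trx

def preprocess_data():
    # Runs as ordinary single-process Python in the launcher script.
    ...

def train():
    # Only this function is distributed across the hosts/GPUs.
    ...

def evaluate(model):
    # Back to a single process once launch() returns.
    ...

if __name__ == "__main__":
    preprocess_data()
    result = trx.launch(func=train, hostnames=["localhost", "other_node"], workers_per_host=2)
    evaluate(result.rank(0))  # train()'s return value from the rank-0 worker
```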
Convenience
- If you don't want to:
  - set up `dist.init_process_group` yourself
  - manually SSH into every machine, run `torchrun --master-addr ... --master-port ...`, babysit failed processes, etc.
Project details
Download files
Source Distribution
torchrunx-0.2.0.tar.gz (47.8 kB)
Built Distribution
torchrunx-0.2.0-py3-none-any.whl (17.5 kB)
File details
Details for the file torchrunx-0.2.0.tar.gz.
File metadata
- Download URL: torchrunx-0.2.0.tar.gz
- Upload date:
- Size: 47.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 9838e74a2c6eea8b5446428c58baf804a9f8128558f9969b53a637948fd2666a
MD5 | a4939f5d124edb54d6cfccf230612266
BLAKE2b-256 | aec2843c9430f9abf805a9136179e890f1a65196b730a1a6db65588335250f99
File details
Details for the file torchrunx-0.2.0-py3-none-any.whl.
File metadata
- Download URL: torchrunx-0.2.0-py3-none-any.whl
- Upload date:
- Size: 17.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 69fbf0fc600868c6eb778c6cfbdbc90231a11f95a01a0d8fbefb0d30ee84dc7a
MD5 | e90c4a9cab7391adbef7df560aeeff4d
BLAKE2b-256 | f84425413b79696d5d95864df86d3ed7d229b8578c27ed714d6384172bf757bc