Automatically initialize distributed PyTorch environments
Project description
torchrunx 🔥
By Apoorv Khandelwal and Peter Curtin
Automatically distribute PyTorch functions onto multiple machines or GPUs
Installation
pip install torchrunx
Requires: Linux (with shared filesystem & SSH access if using multiple machines)
Demo
Here's a simple example where we "train" a model on two nodes (with 2 GPUs each).
Training code
import os
import torch

def train():
    # torchrunx sets the usual distributed environment variables for each worker
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    # build the model on this worker's GPU and wrap it in DDP
    model = torch.nn.Linear(10, 10).to(local_rank)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    # one toy optimization step
    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(5, 10))
    labels = torch.randn(5, 10).to(local_rank)
    torch.nn.functional.mse_loss(outputs, labels).backward()
    optimizer.step()

    # return the (unwrapped) model from the rank-0 worker only
    if rank == 0:
        return model
You could also use transformers.Trainer (or similar) to automatically handle all the multi-GPU / DDP code above (see the sketch after the launcher code below).
import torchrunx as trx

if __name__ == "__main__":
    result = trx.launch(
        func=train,
        hostnames=["localhost", "other_node"],
        workers_per_host=2  # number of GPUs
    )

    trained_model = result.rank(0)
    torch.save(trained_model.state_dict(), "model.pth")
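As mentioned above, you could hand a Trainer-based function to trx.launch() instead of writing the DDP code yourself. Here is a hedged sketch of what that might look like; the checkpoint name, toy dataset, and hyperparameters are illustrative assumptions, not part of the torchrunx docs. transformers.Trainer picks up the distributed environment variables that the launcher sets for each worker.

import os
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

def train():
    # Illustrative checkpoint; substitute your own model.
    model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    # Tiny in-memory dataset, just to keep the sketch self-contained.
    texts = ["a great example", "a terrible example"] * 32
    labels = [1, 0] * 32
    encodings = tokenizer(texts, truncation=True, padding=True)

    class ToyDataset(torch.utils.data.Dataset):
        def __len__(self):
            return len(labels)

        def __getitem__(self, i):
            item = {k: torch.tensor(v[i]) for k, v in encodings.items()}
            item["labels"] = torch.tensor(labels[i])
            return item

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="output", per_device_train_batch_size=8, num_train_epochs=1),
        train_dataset=ToyDataset(),
    )
    trainer.train()  # Trainer handles the DDP wrapping across the launched workers

    if int(os.environ["RANK"]) == 0:
        return trainer.model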
Full API
Advanced Usage
Why should I use this?
Whether you have 1 GPU, 8 GPUs, or 8 machines:
Features
- Our launch() utility is super Pythonic
  - Return objects from your workers
  - Run python script.py instead of torchrun script.py
  - Launch multi-node functions, even from Python Notebooks
- Fine-grained control over logging, environment variables, exception handling, etc.
- Automatic integration with SLURM
Robustness
- If you want to run a complex, modular workflow in one script:
  - don't parallelize your entire script: just the functions you want! (see the sketch after this list)
  - no worries about memory leaks or OS failures
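For instance, here is a minimal sketch of that pattern, using only the launch API shown in the demo above; the hostnames, worker counts, and the train()/evaluate() bodies are placeholders standing in for your own functions.

import os
import torch
import torchrunx as trx

def train():
    # distributed training, e.g. the DDP code from the demo above
    model = torch.nn.Linear(10, 10).to(int(os.environ["LOCAL_RANK"]))
    ...
    if int(os.environ["RANK"]) == 0:
        return model

def evaluate():
    # stand-in for a distributed evaluation function
    return 0.0

if __name__ == "__main__":
    hosts = ["localhost", "other_node"]
    # Only these two calls fan out onto the cluster; argument parsing, logging,
    # checkpointing, plotting, etc. stay in this single local process.
    model = trx.launch(func=train, hostnames=hosts, workers_per_host=2).rank(0)
    score = trx.launch(func=evaluate, hostnames=hosts, workers_per_host=2).rank(0)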
Convenience
- If you don't want to:
  - set up dist.init_process_group yourself (a sketch of this manual setup appears below)
  - manually SSH into every machine and torchrun --master-ip --master-port ..., babysit failed processes, etc.
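For contrast, here is a minimal sketch (plain PyTorch, not torchrunx code) of the per-process setup that torchrunx automates: every process on every node must run something like this, with MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK, and WORLD_SIZE already set by torchrun, SLURM, or by hand.

import os
import torch.distributed as dist

# Requires MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE in the environment.
dist.init_process_group(
    backend="nccl",
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
)

# ... your training code here ...

dist.destroy_process_group()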
Download files
Source Distribution
torchrunx-0.2.3.tar.gz (48.2 kB)

Built Distribution
torchrunx-0.2.3-py3-none-any.whl (18.0 kB)
File details
Details for the file torchrunx-0.2.3.tar.gz.
File metadata
- Download URL: torchrunx-0.2.3.tar.gz
- Upload date:
- Size: 48.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 735ccdfcde8418f7f7bda8ca70490be231072a66a5a3ab297d6410b6df27ee63
MD5 | 1f15a8ab0d7935d655a8bd24d9ca4cdb
BLAKE2b-256 | d12a1b7f0465d887c394d3e44944df80fb3648a9d9a54b51f80f56adbf7ceacb
File details
Details for the file torchrunx-0.2.3-py3-none-any.whl.
File metadata
- Download URL: torchrunx-0.2.3-py3-none-any.whl
- Upload date:
- Size: 18.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | f9a3960ddef1d91eca26c68497cd1b34f71aece16e7f3145408ca5e2f4272171
MD5 | 0e11ef99229d686765afdafd005871c4
BLAKE2b-256 | b692464917a188b1bb69ea06160fb632910ab82a6ebd70da8335f9cc589126cc