torchrunx 🔥
By Apoorv Khandelwal and Peter Curtin
Automatically distribute PyTorch functions onto multiple machines or GPUs
Installation
```bash
pip install torchrunx
```
Requires: Linux (with shared filesystem & SSH access if using multiple machines)
Demo
Here's a simple example where we "train" a model on two nodes (with 2 GPUs each).
Training code
```python
import os
import torch

def train():
    # torchrunx initializes the distributed environment (process group,
    # RANK, LOCAL_RANK, etc.) before invoking this function on each worker.
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    model = torch.nn.Linear(10, 10).to(local_rank)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    # One dummy training step
    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(5, 10))
    labels = torch.randn(5, 10).to(local_rank)
    torch.nn.functional.mse_loss(outputs, labels).backward()
    optimizer.step()

    if rank == 0:
        return model
```
You could also use transformers.Trainer (or similar) to automatically handle all the multi-GPU / DDP code above.
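For instance, a `train()` function along these lines (a rough sketch, not torchrunx code; the model, dataset, and hyperparameters are illustrative placeholders) lets `Trainer` handle device placement and DDP wrapping itself:

```python
import os

from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

def train():
    # torchrunx has already set RANK / LOCAL_RANK / WORLD_SIZE and initialized
    # the process group, so Trainer detects the distributed setup on its own.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

    dataset = load_dataset("imdb", split="train[:1%]")
    dataset = dataset.map(
        lambda batch: tokenizer(
            batch["text"], truncation=True, padding="max_length", max_length=128
        ),
        batched=True,
    )

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="output", per_device_train_batch_size=8),
        train_dataset=dataset,
    )
    trainer.train()

    if int(os.environ["RANK"]) == 0:
        return trainer.model
```

Whichever training function you use, the launcher code stays the same: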
```python
import torchrunx as trx

if __name__ == "__main__":
    result = trx.launch(
        func=train,
        hostnames=["localhost", "other_node"],
        workers_per_host=2,  # number of GPUs per node
    )

    trained_model = result.rank(0)  # object returned by the rank-0 worker
    torch.save(trained_model.state_dict(), "model.pth")
```
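You run this as an ordinary script (`python script.py`) rather than through `torchrun`; for multi-machine launches, torchrunx only needs the SSH access and shared filesystem noted in the requirements above.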
See the documentation for the Full API and Advanced Usage.
Why should I use this?
Whether you have 1 GPU, 8 GPUs, or 8 machines:
Features
- Our `launch()` utility is super Pythonic
  - Return objects from your workers (see the sketch after this list)
  - Run `python script.py` instead of `torchrun script.py`
  - Launch multi-node functions, even from Python Notebooks
- Fine-grained control over logging, environment variables, exception handling, etc.
- Automatic integration with SLURM
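For example, each worker can hand an ordinary Python object back to the launcher (a minimal sketch; `compute_square` is a made-up function, and we assume `result.rank(i)` retrieves any worker's return value, just as `result.rank(0)` does in the demo above):

```python
import os
import torchrunx as trx

def compute_square():
    # Each worker simply returns a regular Python object.
    rank = int(os.environ["RANK"])
    return rank ** 2

if __name__ == "__main__":
    result = trx.launch(
        func=compute_square,
        hostnames=["localhost"],
        workers_per_host=4,
    )
    # Collect each worker's return value by global rank.
    print([result.rank(i) for i in range(4)])  # [0, 1, 4, 9]
```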
Robustness
- If you want to run a complex, modular workflow in one script:
  - don't parallelize your entire script: just the functions you want!
  - no worries about memory leaks or OS failures
Convenience
- If you don't want to:
  - set up `dist.init_process_group` yourself
  - manually SSH into every machine and run `torchrun --master-addr --master-port ...`, babysit failed processes, etc.
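For comparison, here is roughly the boilerplate torchrunx takes care of for you. This is a sketch of the standard `torch.distributed` pattern (not torchrunx code), with illustrative `torchrun` arguments:

```python
import os

import torch.distributed as dist

# Without torchrunx, you would start this script yourself on every node, e.g.:
#   torchrun --nnodes=2 --nproc-per-node=2 --node-rank=<i> \
#       --master-addr=<head node> --master-port=29500 script.py

def manual_setup():
    # torchrun sets MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, and LOCAL_RANK;
    # init_process_group() reads them from the environment.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    return dist.get_rank(), dist.get_world_size(), local_rank
```

With torchrunx, none of this appears in your script: `trx.launch()` performs the equivalent setup on each host before calling your function.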
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
torchrunx-0.2.4.tar.gz (101.7 kB, details below)
Built Distribution
torchrunx-0.2.4-py3-none-any.whl (18.0 kB, details below)
File details
Details for the file torchrunx-0.2.4.tar.gz.
File metadata
- Download URL: torchrunx-0.2.4.tar.gz
- Upload date:
- Size: 101.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.5.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3fc35e4c11d9de16e47603ec0966f0f2fec9b6091ab66f3a13349b480ac6da84 |
| MD5 | 2d75da4913f8ffc4d42bc0630913d27f |
| BLAKE2b-256 | b40538e3cabd81c92c67d3021acd17832909d867180387d05b575d67f7f3f111 |
File details
Details for the file torchrunx-0.2.4-py3-none-any.whl.
File metadata
- Download URL: torchrunx-0.2.4-py3-none-any.whl
- Upload date:
- Size: 18.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.5.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d684b1db329e5ed009b1f58b437f58f6714a7bebd10f748f11f23989c2784436 |
| MD5 | aae64770c0a1aafd8de8d8f5c0ceaf09 |
| BLAKE2b-256 | 6a82b5d8c8cde4b6b7ffb9d08d4511cc512999c61041086f0cf0b796f51efc30 |