Automatically initialize distributed PyTorch environments
torchrunx 🔥
By Apoorv Khandelwal and Peter Curtin
The easiest way to run PyTorch on multiple GPUs or machines.
torchrunx is a functional utility for distributing PyTorch code across devices. This is a more convenient, robust, and featureful alternative to CLI-based launchers, like torchrun, accelerate launch, and deepspeed.
It enables complex workflows within a single script and has useful features even if you're only using 1 GPU.
```bash
pip install torchrunx
```
Requires: Linux. If using multiple machines: SSH & shared filesystem.
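To illustrate the "single script" point: with a CLI launcher such as torchrun, your entire script is re-executed once per worker, whereas torchrunx only replicates the function you hand to the launcher, so single-process setup and post-processing can live in the same file. A minimal sketch, not the library's own example: `prepare_data` and `evaluate` are hypothetical helpers, and we use the `Launcher` arguments shown in the example further below.

```python
# CLI equivalent (re-runs the *whole* script on every worker):
#   torchrun --nnodes 1 --nproc-per-node 2 train.py

import torchrunx

def prepare_data() -> str:          # hypothetical: single-process setup
    return "data/"

def evaluate(path: str) -> None:    # hypothetical: single-process post-processing
    print(f"evaluating with data at {path}")

def train():
    ...                             # replicated on every GPU worker by torchrunx

def main():
    path = prepare_data()                                     # runs once, in this process
    torchrunx.Launcher(hostnames=["localhost"]).run(train)    # fans out to the workers
    evaluate(path)                                            # back to a single process

if __name__ == "__main__":
    main()
```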
Example: simple training loop
Suppose we have some distributed training function (needs to run on every GPU):
```python
def distributed_training(output_dir: str, num_steps: int = 10) -> str:
    # returns path to model checkpoint
    ...
```
Implementation:
```python
from __future__ import annotations

import os

import torch
import torch.nn as nn


def distributed_training(output_dir: str, num_steps: int = 10) -> str | None:
    # The launcher sets these environment variables for every worker process.
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    # Build the model on this worker's GPU and wrap it for data-parallel training.
    model = nn.Linear(10, 10)
    model.to(local_rank)
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    for step in range(num_steps):
        optimizer.zero_grad()

        inputs = torch.randn(5, 10).to(local_rank)
        labels = torch.randn(5, 10).to(local_rank)
        outputs = ddp_model(inputs)

        torch.nn.functional.mse_loss(outputs, labels).backward()
        optimizer.step()

    # Only the first rank saves (and returns the path to) the checkpoint.
    if rank == 0:
        os.makedirs(output_dir, exist_ok=True)
        checkpoint_path = os.path.join(output_dir, "model.pt")
        torch.save(model, checkpoint_path)
        return checkpoint_path

    return None
```
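Note that the function builds a `DistributedDataParallel` model without calling `torch.distributed.init_process_group`, which implies the launcher has already initialized the default process group (and set `RANK`/`LOCAL_LOCAL` variables such as `LOCAL_RANK`) before your function runs. Assuming that, the usual `torch.distributed` collectives are available inside the function. A small sketch, under that assumption, that averages a loss across workers for logging:

```python
import torch
import torch.distributed as dist

def log_mean_loss(loss: torch.Tensor, rank: int) -> None:
    # Average a scalar loss over all workers. Assumes the default process
    # group was initialized by the launcher (as the DDP usage above implies).
    loss = loss.detach().clone()
    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
    loss /= dist.get_world_size()
    if rank == 0:
        print(f"mean loss across workers: {loss.item():.4f}")
```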
We can distribute and run this function (e.g. on 2 machines x 2 GPUs) using torchrunx!
```python
import logging

import torchrunx

logging.basicConfig(level=logging.INFO)

launcher = torchrunx.Launcher(
    hostnames = ["localhost", "second_machine"],  # or IP addresses
    workers_per_host = "gpu"  # default, or just: 2
)

results = launcher.run(
    distributed_training,
    output_dir = "outputs",
    num_steps = 10,
)
```
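If you only have one machine, the same call shrinks to a single hostname. A minimal variant of the example above (sketch only; `workers_per_host` keeps its `"gpu"` default, so one worker is spawned per local GPU):

```python
import torchrunx

# Single-machine launch: one worker per local GPU.
results = torchrunx.Launcher(hostnames=["localhost"]).run(
    distributed_training,
    output_dir="outputs",
)
```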
Once the launch completes, you can retrieve its results and process them as you wish.
```python
checkpoint_path: str = results.rank(0)
# or: results.index(hostname="localhost", local_rank=0)

# and continue your script
model = torch.load(checkpoint_path, weights_only=False)
model.eval()
```
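From here, the checkpoint behaves like any locally trained model. For instance, a quick sanity check on a batch shaped like the training data (a sketch; the shapes simply mirror the example above):

```python
import torch

# The training function used (5, 10)-shaped inputs for its Linear(10, 10) layer.
device = next(model.parameters()).device  # the checkpoint may have been saved from a GPU
with torch.no_grad():
    predictions = model(torch.randn(5, 10, device=device))

print(predictions.shape)  # torch.Size([5, 10])
```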
See more examples in our documentation, where we fine-tune LLMs using various training frameworks.
Refer to our API, Features, and Usage for many more capabilities!
Download files
Download the file for your platform.
- Source Distribution: torchrunx-0.3.4.tar.gz
- Built Distribution: torchrunx-0.3.4-py3-none-any.whl
File details
Details for the file torchrunx-0.3.4.tar.gz.
File metadata
- Download URL: torchrunx-0.3.4.tar.gz
- Upload date:
- Size: 41.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.5.29
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6f2333fa17f7ef1f43f6c65d2b008b8479b29d972a8ed209da613d830dffdc45 |
| MD5 | 41c3bf6a53fc23edbb6dd627375c0253 |
| BLAKE2b-256 | 6560ce6ccaf618e56775905e75e3fe7f2c8adfb61916d2946e854850c1d19a0d |
File details
Details for the file torchrunx-0.3.4-py3-none-any.whl.
File metadata
- Download URL: torchrunx-0.3.4-py3-none-any.whl
- Upload date:
- Size: 34.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.5.29
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a157ec139f5a0bdfaa5ece50d987ff0a3212a4657d791367f11fdd383ccdbd5b |
| MD5 | 752ab8c1beeee4e8ba18aa885cd30ca4 |
| BLAKE2b-256 | 2c0d9f5e24043f2562fd4dc15d04614cb5f3b4a1ffe327cb040b9f95b03fa84d |
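If you want to check a downloaded archive against the digests above, a minimal sketch (it assumes the file sits in the current directory; swap in the wheel's filename and digest to verify the built distribution):

```python
import hashlib

EXPECTED_SHA256 = "6f2333fa17f7ef1f43f6c65d2b008b8479b29d972a8ed209da613d830dffdc45"

# Hash the source distribution and compare with the digest listed above.
with open("torchrunx-0.3.4.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print("OK" if digest == EXPECTED_SHA256 else "hash mismatch")
```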