
Automatically initialize distributed PyTorch environments

Project description

torchrunx 🔥


By Apoorv Khandelwal and Peter Curtin

The easiest way to run PyTorch on multiple GPUs or machines.


torchrunx is a functional utility for distributing PyTorch code across devices. This is a more convenient, robust, and featureful alternative to CLI-based launchers, like torchrun, accelerate launch, and deepspeed.

It enables complex workflows within a single script and offers useful features even if you're only using 1 GPU.

pip install torchrunx

Requires: Linux. If using multiple machines: SSH & shared filesystem.


Example: simple training loop

Suppose we have some distributed training function (needs to run on every GPU):

def distributed_training(output_dir: str, num_steps: int = 10) -> str | None:
    # returns path to model checkpoint (from rank 0; other ranks return None)

Implementation:
from __future__ import annotations
import os
import torch
import torch.nn as nn

def distributed_training(output_dir: str, num_steps: int = 10) -> str | None:
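    # RANK / LOCAL_RANK (and the other torchrun-style environment variables)
    # are set by the torchrunx launcher for each worker process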
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    model = nn.Linear(10, 10)
    model.to(local_rank)
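    # the launcher has already initialized the default process group,
    # so the model can be wrapped in DDP directly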
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    for step in range(num_steps):
        optimizer.zero_grad()

        inputs = torch.randn(5, 10).to(local_rank)
        labels = torch.randn(5, 10).to(local_rank)
        outputs = ddp_model(inputs)

        torch.nn.functional.mse_loss(outputs, labels).backward()
        optimizer.step()

    if rank == 0:
        os.makedirs(output_dir, exist_ok=True)
        checkpoint_path = os.path.join(output_dir, "model.pt")
        torch.save(model, checkpoint_path)
        return checkpoint_path

    return None

We can distribute and run this function (e.g. on 2 machines x 2 GPUs) using torchrunx!

import logging
import torch
import torchrunx

logging.basicConfig(level=logging.INFO)

launcher = torchrunx.Launcher(
    hostnames = ["localhost", "second_machine"],  # or IP addresses
    workers_per_host = "gpu"  # default, or just: 2
)

results = launcher.run(
    distributed_training,
    output_dir = "outputs",
    num_steps = 10,
)

Once completed, you can retrieve the results and process them as you wish.

checkpoint_path: str = results.rank(0)
                 # or: results.index(hostname="localhost", local_rank=0)

# and continue your script
model = torch.load(checkpoint_path, weights_only=False)
model.eval()
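
Because launching happens inside an ordinary Python script rather than a separate CLI invocation, you can also chain several distributed stages and pass results between them. A minimal sketch, reusing the launcher and the run / rank calls from above; evaluate_model is a hypothetical function that you would define just like distributed_training:

# Stage 1: train (as above) and grab the checkpoint path returned by rank 0.
train_results = launcher.run(distributed_training, output_dir="outputs", num_steps=10)
checkpoint_path: str = train_results.rank(0)

# Stage 2: reuse the same launcher to run a second distributed function.
# evaluate_model is hypothetical here: define it like distributed_training above.
eval_results = launcher.run(evaluate_model, checkpoint_path=checkpoint_path)
print(eval_results.rank(0))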

See our documentation for more examples, where we fine-tune LLMs using various frameworks.

Refer to our API, Features, and Usage for many more capabilities!

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchrunx-0.3.4.tar.gz (41.3 kB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

torchrunx-0.3.4-py3-none-any.whl (34.2 kB)

Uploaded Python 3

File details

Details for the file torchrunx-0.3.4.tar.gz.

File metadata

  • Download URL: torchrunx-0.3.4.tar.gz
  • Upload date:
  • Size: 41.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.5.29

File hashes

Hashes for torchrunx-0.3.4.tar.gz

  • SHA256: 6f2333fa17f7ef1f43f6c65d2b008b8479b29d972a8ed209da613d830dffdc45
  • MD5: 41c3bf6a53fc23edbb6dd627375c0253
  • BLAKE2b-256: 6560ce6ccaf618e56775905e75e3fe7f2c8adfb61916d2946e854850c1d19a0d

See more details on using hashes here.
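
If you download the sdist manually, you can verify it against the SHA256 digest listed above. A minimal sketch in Python (it assumes torchrunx-0.3.4.tar.gz sits in the current directory):

import hashlib

EXPECTED_SHA256 = "6f2333fa17f7ef1f43f6c65d2b008b8479b29d972a8ed209da613d830dffdc45"

with open("torchrunx-0.3.4.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

# fail loudly if the download does not match the published hash
assert digest == EXPECTED_SHA256, "SHA256 mismatch: the file may be corrupted or tampered with"
print("torchrunx-0.3.4.tar.gz: SHA256 OK")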

File details

Details for the file torchrunx-0.3.4-py3-none-any.whl.

File metadata

  • Download URL: torchrunx-0.3.4-py3-none-any.whl
  • Upload date:
  • Size: 34.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.5.29

File hashes

Hashes for torchrunx-0.3.4-py3-none-any.whl

  • SHA256: a157ec139f5a0bdfaa5ece50d987ff0a3212a4657d791367f11fdd383ccdbd5b
  • MD5: 752ab8c1beeee4e8ba18aa885cd30ca4
  • BLAKE2b-256: 2c0d9f5e24043f2562fd4dc15d04614cb5f3b4a1ffe327cb040b9f95b03fa84d

See more details on using hashes here.
