
Automatically initialize distributed PyTorch environments

Project description

torchrunx 🔥


By Apoorv Khandelwal and Peter Curtin

The easiest way to run PyTorch on multiple GPUs or machines.


torchrunx is a functional utility for distributing PyTorch code across devices. This is a more convenient, robust, and featureful alternative to CLI-based launchers, like torchrun, accelerate launch, and deepspeed.

It enables complex workflows within a single script and has useful features even if only using 1 GPU.

pip install torchrunx

Requires: Linux. If using multiple machines: SSH & shared filesystem.


Example: simple training loop

Suppose we have some distributed training function (which needs to run on every GPU):

def distributed_training(model: nn.Module, num_steps: int) -> nn.Module: ...
Implementation of distributed_training:
from __future__ import annotations
import os
import torch
import torch.nn as nn

def distributed_training(model: nn.Module, num_steps: int = 10) -> nn.Module | None:
    # torchrunx initializes the distributed environment (the process group and the
    # usual environment variables) before invoking this function on each worker.
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    model.to(local_rank)
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    for step in range(num_steps):
        optimizer.zero_grad()

        inputs = torch.randn(5, 10).to(local_rank)
        labels = torch.randn(5, 10).to(local_rank)
        outputs = ddp_model(inputs)

        torch.nn.functional.mse_loss(outputs, labels).backward()
        optimizer.step()

    if rank == 0:
        return model.cpu()

We can distribute and run this function (e.g. on 2 machines x 2 GPUs) using torchrunx!

import logging
logging.basicConfig(level=logging.INFO)

import torch
import torch.nn as nn

import torchrunx

launcher = torchrunx.Launcher(
    hostnames = ["localhost", "second_machine"],  # or IP addresses
    workers_per_host = 2  # e.g. number of GPUs per host
)

results = launcher.run(
    distributed_training,
    model = nn.Linear(10, 10),
    num_steps = 10
)

Once completed, you can retrieve the results and process them as you wish.

trained_model: nn.Module = results.rank(0)
                     # or: results.index(hostname="localhost", local_rank=0)

# and continue your script
torch.save(trained_model.state_dict(), "output/model.pth")

See our documentation for more examples, e.g. where we fine-tune LLMs.

Refer to our API and Usage for many more capabilities!


torchrunx uniquely offers

  1. An automatic launcher that "just works" for everyone 🚀

torchrunx is an SSH-based, pure-Python library that is universally easy to install.
It needs no system-specific dependencies or extra orchestration to automatically distribute work across multiple nodes.
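
The same API also applies unchanged on a single machine with a single GPU. A minimal sketch, reusing the distributed_training function and the Launcher interface shown above:

import torch.nn as nn
import torchrunx

# Same Launcher API as the multi-node example, restricted to this machine.
launcher = torchrunx.Launcher(hostnames=["localhost"], workers_per_host=1)
results = launcher.run(distributed_training, model=nn.Linear(10, 10), num_steps=10)
trained_model = results.rank(0)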

  2. Conventional CLI commands 🖥️

Run familiar commands, like python my_script.py ..., and customize arguments as you wish.

Other launchers override python in a cumbersome way: e.g. torchrun --nproc_per_node=2 --nnodes=2 --node_rank=0 --master_addr=100.43.331.111 --master_port=1234 my_script.py ....
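
With torchrunx, the same thing can stay an ordinary script. Below is an illustrative sketch: the command-line flags are arbitrary examples defined with argparse (not torchrunx options), and you would invoke it simply as python my_script.py --num-steps 20.

# my_script.py
import argparse

import torch.nn as nn
import torchrunx

# distributed_training is the function defined earlier in this script (see above).

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--hosts", nargs="+", default=["localhost"])
    parser.add_argument("--workers-per-host", type=int, default=2)
    parser.add_argument("--num-steps", type=int, default=10)
    args = parser.parse_args()

    launcher = torchrunx.Launcher(
        hostnames=args.hosts, workers_per_host=args.workers_per_host
    )
    results = launcher.run(
        distributed_training, model=nn.Linear(10, 10), num_steps=args.num_steps
    )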

  3. Support for more complex workflows in a single script 🎛️

Your workflow may have steps that are complex (e.g. pre-train, fine-tune, test) or that require different parallelizations (e.g. training on 8 GPUs, testing on 1 GPU). In these cases, CLI-based launchers force each step into its own script. Our library treats these steps in a modular way, so they can cleanly fit together in a single script!

We clean memory leaks as we go, so previous steps won't crash or adversely affect future steps.
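
As an illustrative sketch of such a two-stage workflow (evaluate is a hypothetical second-stage function, not part of the example above; distributed_training is the function from the example):

import torch.nn as nn
import torchrunx

def evaluate(model: nn.Module) -> float:
    ...  # e.g. compute a validation metric on a single GPU

# Stage 1: train across 8 GPUs on this machine.
trained_model = (
    torchrunx.Launcher(hostnames=["localhost"], workers_per_host=8)
    .run(distributed_training, model=nn.Linear(10, 10), num_steps=10)
    .rank(0)
)

# Stage 2: test on 1 GPU, later in the same script.
score = (
    torchrunx.Launcher(hostnames=["localhost"], workers_per_host=1)
    .run(evaluate, model=trained_model)
    .rank(0)
)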

  4. Better handling of system failures. No more zombies! 🧟

With torchrun, your "work" is inherently coupled to your main Python process. If the system kills one of your workers (e.g. due to RAM OOM or segmentation faults), there is no way to fail gracefully in Python. Your processes might hang for 10 minutes (the NCCL timeout) or become perpetual zombies.

torchrunx decouples "launcher" and "worker" processes. If the system kills a worker, our launcher immediately raises a WorkerFailure exception, which users can handle as they wish. We always clean up all nodes, so no more zombies!
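
For example, you can catch the failure and decide how to recover. A sketch, assuming the exception is exposed at the package root; confirm the exact name and import path in the API reference:

import torch.nn as nn
import torchrunx

launcher = torchrunx.Launcher(hostnames=["localhost"], workers_per_host=2)

try:
    results = launcher.run(distributed_training, model=nn.Linear(10, 10), num_steps=10)
except torchrunx.WorkerFailure:  # assumed import path; see the API docs
    # The launcher has already cleaned up all nodes; recover however you like,
    # e.g. log the failure, retry with fewer workers, or fall back to CPU.
    ...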

  5. Bonus features 🎁

On our roadmap: higher-order parallelism, support for debuggers, and more!



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchrunx-0.3.1.tar.gz (379.7 kB)

Uploaded Source

Built Distribution

torchrunx-0.3.1-py3-none-any.whl (46.7 kB)

Uploaded Python 3

File details

Details for the file torchrunx-0.3.1.tar.gz.

File metadata

  • Download URL: torchrunx-0.3.1.tar.gz
  • Size: 379.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.5.29

File hashes

Hashes for torchrunx-0.3.1.tar.gz
Algorithm Hash digest
SHA256 b9ee406ead0954ca36634bb20c9d268891d3bb41d01c830a32b3f76bf4235732
MD5 d3ac9c30f05cd1c8bcca8a4ca41b5f7f
BLAKE2b-256 26bed505087af27158310ac907838cb929f50ea1276aecd4e89a47e3cd902919

See more details on using hashes here.
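
For example, to check a downloaded copy of the sdist against the SHA256 digest above, a minimal sketch (adjust the path to wherever you saved the file):

import hashlib
from pathlib import Path

expected_sha256 = "b9ee406ead0954ca36634bb20c9d268891d3bb41d01c830a32b3f76bf4235732"
actual_sha256 = hashlib.sha256(Path("torchrunx-0.3.1.tar.gz").read_bytes()).hexdigest()
assert actual_sha256 == expected_sha256, "SHA256 mismatch: not the published file"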

File details

Details for the file torchrunx-0.3.1-py3-none-any.whl.

File hashes

Hashes for torchrunx-0.3.1-py3-none-any.whl
Algorithm Hash digest
SHA256 b0b279becd58755f0c39cc947ed6bfe08f31b7eb6c73de5f8b7f7b8259d3f6ba
MD5 8532d36e6f7992c2f1a98d5be1eb5f17
BLAKE2b-256 5e3675ed8d8f25908a6f2f6b092952a22b2b9ea6cb418f263de5be4ef4376aa3

See more details on using hashes here.
