
Automatically initialize distributed PyTorch environments

Project description

torchrunx 🔥


By Apoorv Khandelwal and Peter Curtin

Automatically distribute PyTorch functions onto multiple machines or GPUs

Installation

pip install torchrunx

Requires: Linux (with shared filesystem & SSH access if using multiple machines)

Demo

Here's a simple example where we "train" a model on two nodes (with 2 GPUs each).

Training code
import os
import torch

def train():
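    # torchrunx sets the distributed environment variables (e.g. RANK, LOCAL_RANK) for each worker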
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    model = torch.nn.Linear(10, 10).to(local_rank)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(5, 10))
    labels = torch.randn(5, 10).to(local_rank)
    torch.nn.functional.mse_loss(outputs, labels).backward()
    optimizer.step()

    if rank == 0:
        return model  # only rank 0 returns the trained model; the other workers return None

You could also use transformers.Trainer (or similar) to automatically handle all the multi-GPU / DDP code above.
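For reference, a Trainer-based train() might look something like the sketch below. This is a rough illustration rather than torchrunx code: TinyModel and ToyDataset are made-up stand-ins, and Trainer simply picks up the distributed environment that torchrunx configures.

import os
import torch
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments

class TinyModel(torch.nn.Module):
    # hypothetical stand-in model; Trainer expects forward() to return a dict containing "loss"
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, inputs=None, labels=None):
        outputs = self.linear(inputs)
        return {"loss": torch.nn.functional.mse_loss(outputs, labels), "logits": outputs}

class ToyDataset(Dataset):
    # hypothetical stand-in dataset of random tensors
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return {"inputs": torch.randn(10), "labels": torch.randn(10)}

def train():
    trainer = Trainer(
        model=TinyModel(),
        args=TrainingArguments(output_dir="output", max_steps=10, report_to="none"),
        train_dataset=ToyDataset(),
    )
    trainer.train()  # Trainer wraps the model in DDP using the environment torchrunx set up
    if int(os.environ["RANK"]) == 0:
        return trainer.model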

import torch
import torchrunx as trx

if __name__ == "__main__":
    result = trx.launch(
        func=train,
        hostnames=["localhost", "other_node"],
        workers_per_host=2  # number of GPUs per host
    )

    trained_model = result.rank(0)  # return value from the rank 0 worker
    torch.save(trained_model.state_dict(), "model.pth")
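If your function takes arguments, one option (a sketch on our part, not necessarily the mechanism the library prefers; see the Full API for launch()'s own options) is to bind them with functools.partial so the launched callable needs no arguments:

import functools
import torchrunx as trx

def train(num_steps: int):
    ...  # as in the demo above, but parameterized

if __name__ == "__main__":
    result = trx.launch(
        # bound callable; assumes the launcher can serialize it for the workers
        func=functools.partial(train, num_steps=10),
        hostnames=["localhost", "other_node"],
        workers_per_host=2,
    )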

Full API

Advanced Usage

Why should I use this?

Whether you have 1 GPU, 8 GPUs, or 8 machines:

Features

  • Our launch() utility is super Pythonic
    • Return objects from your workers
    • Run python script.py instead of torchrun script.py
    • Launch multi-node functions, even from Python Notebooks
  • Fine-grained control over logging, environment variables, exception handling, etc. (see the sketch after this list)
  • Automatic integration with SLURM
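For instance, worker failures can be handled in the launcher process like ordinary Python exceptions. A minimal sketch (assuming, per the feature list, that torchrunx surfaces worker exceptions through launch(); the exact behavior and exception types are documented in the Full API):

import torchrunx as trx

def flaky():
    raise RuntimeError("worker failed")

if __name__ == "__main__":
    try:
        trx.launch(func=flaky, hostnames=["localhost"], workers_per_host=1)
    except Exception as err:
        # handle or log the failure here instead of babysitting crashed processes
        print(f"Distributed run failed: {err}")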

Robustness

  • If you want to run a complex, modular workflow in one script
    • don't parallelize your entire script: just the functions you want! (see the sketch after this list)
    • no worries about memory leaks or OS failures
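For example, one driver script can launch several distributed functions back to back while the rest of the workflow stays in ordinary single-process Python. A minimal sketch using only the launch() API shown above (train_step and evaluate_step are hypothetical stand-ins for your own functions):

import torchrunx as trx

def train_step():
    ...  # distributed training, as in the demo above

def evaluate_step():
    ...  # distributed evaluation

if __name__ == "__main__":
    # each launch() spins up workers, runs the function, and tears the workers down again
    trained = trx.launch(func=train_step, hostnames=["localhost"], workers_per_host=2).rank(0)
    metrics = trx.launch(func=evaluate_step, hostnames=["localhost"], workers_per_host=2).rank(0)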

Convenience

  • If you don't want to:
    • set up dist.init_process_group yourself
    • manually SSH into every machine, run torchrun --master-addr --master-port ..., babysit failed processes, etc.



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchrunx-0.2.4.tar.gz (101.7 kB)

Built Distribution

torchrunx-0.2.4-py3-none-any.whl (18.0 kB, Python 3)

File details

Details for the file torchrunx-0.2.4.tar.gz.

File metadata

  • Download URL: torchrunx-0.2.4.tar.gz
  • Upload date:
  • Size: 101.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.5.0

File hashes

Hashes for torchrunx-0.2.4.tar.gz

  • SHA256: 3fc35e4c11d9de16e47603ec0966f0f2fec9b6091ab66f3a13349b480ac6da84
  • MD5: 2d75da4913f8ffc4d42bc0630913d27f
  • BLAKE2b-256: b40538e3cabd81c92c67d3021acd17832909d867180387d05b575d67f7f3f111


File details

Details for the file torchrunx-0.2.4-py3-none-any.whl.


File hashes

Hashes for torchrunx-0.2.4-py3-none-any.whl

  • SHA256: d684b1db329e5ed009b1f58b437f58f6714a7bebd10f748f11f23989c2784436
  • MD5: aae64770c0a1aafd8de8d8f5c0ceaf09
  • BLAKE2b-256: 6a82b5d8c8cde4b6b7ffb9d08d4511cc512999c61041086f0cf0b796f51efc30

