
Automatically initialize distributed PyTorch environments

Project description

torchrunx 🔥


By Apoorv Khandelwal and Peter Curtin

Automatically distribute PyTorch functions onto multiple machines or GPUs

Installation

pip install torchrunx

Requires: Linux (with shared filesystem & SSH access if using multiple machines)

Demo

Here's a simple example where we "train" a model on two nodes (with 2 GPUs each).

Training code
import os
import torch

def train():
    # torchrunx initializes torch.distributed and sets the usual environment
    # variables (RANK, LOCAL_RANK, WORLD_SIZE, ...) in each worker process
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    model = torch.nn.Linear(10, 10).to(local_rank)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    # one "training" step
    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(5, 10).to(local_rank))  # input must be on the model's device
    labels = torch.randn(5, 10).to(local_rank)
    torch.nn.functional.mse_loss(outputs, labels).backward()
    optimizer.step()

    # only rank 0 returns the trained model to the launcher
    if rank == 0:
        return model

You could also use transformers.Trainer (or similar) to automatically handle all the multi-GPU / DDP code above.
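
For instance, here's a minimal sketch of that variant (assuming transformers is installed; model and dataset are placeholders for objects you would build inside the function):

def train():
    # transformers.Trainer reads RANK / LOCAL_RANK / WORLD_SIZE from the
    # environment that the launcher sets, and applies the DDP wrapping itself
    from transformers import Trainer, TrainingArguments

    trainer = Trainer(
        model=model,            # placeholder: your torch.nn.Module
        args=TrainingArguments(output_dir="output"),
        train_dataset=dataset,  # placeholder: your training dataset
    )
    trainer.train()
    return trainer.model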

import torch
import torchrunx as trx

if __name__ == "__main__":
    result = trx.launch(
        func=train,
        hostnames=["localhost", "other_node"],
        workers_per_host=2  # number of GPUs
    )

    trained_model = result.rank(0)
    torch.save(trained_model.state_dict(), "model.pth")

Full API

Advanced Usage

Why should I use this?

Whether you have 1 GPU, 8 GPUs, or 8 machines:

Features

  • Our launch() utility is super Pythonic
    • Return objects from your workers
    • Run python script.py instead of torchrun script.py
    • Launch multi-node functions, even from Python Notebooks
  • Fine-grained control over logging, environment variables, exception handling, etc.
  • Automatic integration with SLURM (see the sketch after this list)
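
For example, inside a SLURM job, the launcher can discover the allocated nodes itself. A minimal sketch (hostnames="slurm" is our assumed spelling for that option; check the Full API docs for the exact name):

import torchrunx as trx

if __name__ == "__main__":
    result = trx.launch(
        func=train,
        hostnames="slurm",   # assumption: read the node list from the SLURM environment
        workers_per_host=2,
    )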

Robustness

  • If you want to run a complex, modular workflow in one script
    • don't parallelize your entire script: just the functions you want (see the sketch after this list)!
    • no worries about memory leaks or OS failures
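
A minimal sketch of that pattern: only the stages passed to trx.launch() run in parallel worker processes, and each launch's workers exit when the function returns, freeing their GPU memory before the next stage (evaluate is a hypothetical function, defined like train above):

import torchrunx as trx

if __name__ == "__main__":
    # ordinary single-process code: argument parsing, data prep, ...

    # stage 1: distributed training
    result = trx.launch(
        func=train,
        hostnames=["localhost", "other_node"],
        workers_per_host=2,
    )
    trained_model = result.rank(0)

    # stage 1's workers have exited by now, so their GPU memory is released

    # stage 2: distributed evaluation (evaluate: hypothetical, defined like train)
    trx.launch(
        func=evaluate,
        hostnames=["localhost", "other_node"],
        workers_per_host=2,
    )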

Convenience

  • If you don't want to:
    • set up dist.init_process_group yourself
    • manually SSH into every machine, run torchrun --master_addr ... --master_port ..., babysit failed processes, etc. (torchrunx does the setup sketched below for you)
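
For reference, this is roughly the per-process boilerplate that torchrunx takes care of (a sketch using the standard torch.distributed API; the MASTER_ADDR/MASTER_PORT values are placeholders you would otherwise coordinate by hand or via torchrun):

import os
import torch.distributed as dist

# every process must agree on a rendezvous address and know its own rank
os.environ.setdefault("MASTER_ADDR", "master_node")  # placeholder hostname
os.environ.setdefault("MASTER_PORT", "29500")        # placeholder port

dist.init_process_group(
    backend="nccl",  # or "gloo" for CPU-only setups
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
)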

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchrunx-0.2.3.tar.gz (48.2 kB)

Built Distribution

torchrunx-0.2.3-py3-none-any.whl (18.0 kB)

File details

Details for the file torchrunx-0.2.3.tar.gz.

File metadata

  • Download URL: torchrunx-0.2.3.tar.gz
  • Upload date:
  • Size: 48.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.0

File hashes

Hashes for torchrunx-0.2.3.tar.gz

  • SHA256: 735ccdfcde8418f7f7bda8ca70490be231072a66a5a3ab297d6410b6df27ee63
  • MD5: 1f15a8ab0d7935d655a8bd24d9ca4cdb
  • BLAKE2b-256: d12a1b7f0465d887c394d3e44944df80fb3648a9d9a54b51f80f56adbf7ceacb
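
If you want to verify the SHA256 digest above after downloading, a quick check (assuming the sdist is in the current directory):

import hashlib

with open("torchrunx-0.2.3.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

assert digest == "735ccdfcde8418f7f7bda8ca70490be231072a66a5a3ab297d6410b6df27ee63"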


File details

Details for the file torchrunx-0.2.3-py3-none-any.whl.

File metadata

  • Download URL: torchrunx-0.2.3-py3-none-any.whl
  • Upload date:
  • Size: 18.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.0

File hashes

Hashes for torchrunx-0.2.3-py3-none-any.whl

  • SHA256: f9a3960ddef1d91eca26c68497cd1b34f71aece16e7f3145408ca5e2f4272171
  • MD5: 0e11ef99229d686765afdafd005871c4
  • BLAKE2b-256: b692464917a188b1bb69ea06160fb632910ab82a6ebd70da8335f9cc589126cc

