
Automatically initialize distributed PyTorch environments

Project description

torchrunx 🔥


By Apoorv Khandelwal and Peter Curtin

Automatically distribute PyTorch functions onto multiple machines or GPUs

Installation

pip install torchrunx

Requires: Linux (with shared filesystem & SSH access if using multiple machines)

Demo

Here's a simple example where we "train" a model on two nodes (with 2 GPUs each).

Training code
import os
import torch

def train():
    # torchrunx runs this function in every worker process and sets up the
    # distributed environment (process group, RANK / LOCAL_RANK / WORLD_SIZE, ...)
    # before calling it.
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    model = torch.nn.Linear(10, 10).to(local_rank)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    # one dummy training step
    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(5, 10))
    labels = torch.randn(5, 10).to(local_rank)
    torch.nn.functional.mse_loss(outputs, labels).backward()
    optimizer.step()

    # only rank 0 returns the trained model; all other workers return None
    if rank == 0:
        return model

You could also use transformers.Trainer (or similar) to automatically handle all the multi-GPU / DDP code above.
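
For instance, here is a minimal sketch of what that could look like (the bert-base-uncased model and imdb dataset are arbitrary placeholders, not part of torchrunx). Because torchrunx sets up the distributed environment before calling your function, Trainer should detect it and apply DDP on its own.

import os
from datasets import load_dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

def train():
    # placeholder model and dataset; substitute your own
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

    dataset = load_dataset("imdb", split="train[:1%]")
    dataset = dataset.map(
        lambda batch: tokenizer(batch["text"], truncation=True, padding="max_length"),
        batched=True,
    )

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="output", max_steps=10),
        train_dataset=dataset,
    )
    trainer.train()  # Trainer handles the multi-GPU / DDP details

    if int(os.environ["RANK"]) == 0:
        return trainer.model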

import torch
import torchrunx as trx

if __name__ == "__main__":
    result = trx.launch(
        func=train,
        hostnames=["localhost", "other_node"],
        workers_per_host=2  # number of GPUs
    )

    trained_model = result.rank(0)
    torch.save(trained_model.state_dict(), "model.pth")

Full API

Advanced Usage

Why should I use this?

Whether you have 1 GPU, 8 GPUs, or 8 machines:

Features

  • Our launch() utility is super Pythonic
    • Return objects from your workers
    • Run python script.py instead of torchrun script.py
    • Launch multi-node functions, even from Python Notebooks
  • Fine-grained control over logging, environment variables, exception handling, etc.
  • Automatic integration with SLURM

Robustness

  • If you want to run a complex, modular workflow in one script
    • don't parallelize your entire script: just the functions you want (see the sketch below)!
    • no worries about memory leaks or OS failures
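
A minimal sketch of that pattern, using only the launch() API shown in the demo (hostnames and worker counts are placeholders): each distributed stage is launched separately, and the code between the launch() calls stays ordinary single-process Python.

import os
import torch
import torchrunx as trx

def train():
    # first distributed stage: runs on every worker (cf. the demo above)
    local_rank = int(os.environ["LOCAL_RANK"])
    model = torch.nn.Linear(10, 10).to(local_rank)
    if int(os.environ["RANK"]) == 0:
        return model.cpu()

def evaluate():
    # second distributed stage, launched separately from the first
    local_rank = int(os.environ["LOCAL_RANK"])
    return float(torch.randn(5, 10).to(local_rank).mean())

if __name__ == "__main__":
    # ordinary single-process Python between the distributed stages
    trained = trx.launch(func=train, hostnames=["localhost"], workers_per_host=2).rank(0)
    score = trx.launch(func=evaluate, hostnames=["localhost"], workers_per_host=2).rank(0)
    torch.save(trained.state_dict(), "model.pth")
    print(score)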

Convenience

  • If you don't want to:
    • set up dist.init_process_group yourself
    • manually SSH into every machine and run torchrun --master-addr --master-port ..., babysit failed processes, etc.
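
For comparison, a hedged sketch of the manual alternative torchrunx replaces: you would run a command like torchrun --nnodes 2 --nproc-per-node 2 --master-addr <ip> --master-port 29500 script.py on every node yourself, and script.py would have to manage the process group on its own.

# script.py: what each manually launched process would need to do
import torch.distributed as dist

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")  # torchrunx handles this for you
    print(f"worker {dist.get_rank()} / {dist.get_world_size()}")
    dist.destroy_process_group()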

Download files

Download the file for your platform.

Source Distribution

torchrunx-0.2.1.tar.gz (48.2 kB)


Built Distribution

torchrunx-0.2.1-py3-none-any.whl (17.9 kB)


File details

Details for the file torchrunx-0.2.1.tar.gz.

File metadata

  • Download URL: torchrunx-0.2.1.tar.gz
  • Upload date:
  • Size: 48.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.0

File hashes

Hashes for torchrunx-0.2.1.tar.gz

  • SHA256: 0f35c4f337c9ead5a52a44f51bedf2518f1bdc9f424571133779d19f99a3c0f7
  • MD5: ab2b16fa437a67d131b8a53209e1c231
  • BLAKE2b-256: e54fe894bce61502a758f8141587cf172a60584ca0caa8e88e0c4bf3f1f13d8e
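
For example, after downloading the sdist you could check it against the SHA256 digest above (a small sketch; it assumes the file sits in the current directory):

import hashlib

with open("torchrunx-0.2.1.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

assert digest == "0f35c4f337c9ead5a52a44f51bedf2518f1bdc9f424571133779d19f99a3c0f7"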


File details

Details for the file torchrunx-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: torchrunx-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 17.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.0

File hashes

Hashes for torchrunx-0.2.1-py3-none-any.whl

  • SHA256: 99580e69635ac96e8523f094070e4e673cd1945b7a6050645eeb356fb58874fc
  • MD5: 4e43fd5c2755b8011eed57a02533247e
  • BLAKE2b-256: 8ae1cecb3ca72abe233640f5f78b548a821cd396253b65cb9ed23571e4f65bda

