
Automatically initialize distributed PyTorch environments

Project description

torchrunx 🔥


By Apoorv Khandelwal and Peter Curtin

Automatically distribute PyTorch functions onto multiple machines or GPUs

Installation

pip install torchrunx

Requires: Linux (with shared filesystem & SSH access if using multiple machines)

Demo

Here's a simple example where we "train" a model on two nodes (with 2 GPUs each).

Training code
import os
import torch

def train():
    # torchrunx sets the usual distributed environment variables for each worker
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    model = torch.nn.Linear(10, 10).to(local_rank)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    # one dummy training step (inputs must live on the same device as the model)
    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(5, 10).to(local_rank))
    labels = torch.randn(5, 10).to(local_rank)
    torch.nn.functional.mse_loss(outputs, labels).backward()
    optimizer.step()

    # objects returned here are sent back to the launcher process
    if rank == 0:
        return model

You could also use transformers.Trainer (or similar) to automatically handle all the multi-GPU / DDP code above.
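For example, a minimal sketch of such a train() built on transformers.Trainer; the model name, dataset, and hyperparameters here are illustrative assumptions, not part of torchrunx:

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

def train():
    # Trainer picks up the RANK / LOCAL_RANK / WORLD_SIZE environment that
    # torchrunx sets up, so no explicit DDP code is needed here
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

    dataset = load_dataset("imdb", split="train[:1%]")  # placeholder dataset
    dataset = dataset.map(
        lambda batch: tokenizer(batch["text"], truncation=True, padding="max_length"),
        batched=True,
    )

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="output", max_steps=10),
        train_dataset=dataset,
    )
    trainer.train()

    return trainer.model.cpu()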

import torch
import torchrunx as trx

if __name__ == "__main__":
    result = trx.launch(
        func=train,  # the function defined above
        hostnames=["localhost", "other_node"],
        workers_per_host=2  # number of GPUs per host
    )

    trained_model = result.rank(0)  # the object returned by the rank-0 worker
    torch.save(trained_model.state_dict(), "model.pth")

Full API

Advanced Usage

Why should I use this?

Whether you have 1 GPU, 8 GPUs, or 8 machines, torchrunx gives you:

Features

  • Our launch() utility is super Pythonic
    • Return objects from your workers (see the sketch after this list)
    • Run python script.py instead of torchrun script.py
    • Launch multi-node functions, even from Python Notebooks
  • Fine-grained control over logging, environment variables, exception handling, etc.
  • Automatic integration with SLURM
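As a sketch of returning objects from every worker in a plain python script.py (assuming result.rank(i) accepts any global rank, as .rank(0) does in the demo above):

import os
import torchrunx as trx

def get_metric():
    # each worker can return an arbitrary (picklable) Python object
    return int(os.environ["RANK"]) * 0.1  # placeholder per-worker value

if __name__ == "__main__":
    hostnames = ["localhost", "other_node"]
    workers_per_host = 2

    result = trx.launch(func=get_metric, hostnames=hostnames, workers_per_host=workers_per_host)

    world_size = len(hostnames) * workers_per_host
    metrics = [result.rank(i) for i in range(world_size)]
    print(metrics)  # e.g. [0.0, 0.1, 0.2, 0.3]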

Robustness

  • If you want to run a complex, modular workflow in one script:
    • parallelize just the functions you want, not your entire script (see the sketch below)
    • no worries about memory leaks or OS failures
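As a sketch of that pattern (hostnames and worker counts are placeholders): each stage of a pipeline gets its own launch() call, while the surrounding script stays an ordinary single-process Python program.

import torchrunx as trx

def preprocess():
    ...  # e.g. download and tokenize a dataset

def train():
    ...  # e.g. the DDP training function from the demo above

if __name__ == "__main__":
    # run preprocessing on a single worker of one machine
    trx.launch(func=preprocess, hostnames=["localhost"], workers_per_host=1)

    # parallelize only the training step across all GPUs
    result = trx.launch(func=train, hostnames=["node1", "node2"], workers_per_host=8)
    trained_model = result.rank(0)
    # ...evaluate, save, etc. back in the single-process script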

Convenience

  • If you don't want to:
    • set up dist.init_process_group yourself (see the boilerplate sketch below)
    • manually SSH into every machine, run torchrun --master_addr ... --master_port ... on each, babysit failed processes, etc.
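For reference, this is roughly the per-process boilerplate (standard PyTorch, not torchrunx API) that you would otherwise have to write and coordinate across machines yourself:

import os
import torch.distributed as dist

# without torchrunx: every process on every machine must agree on these values,
# which you would normally set by hand or via torchrun on each node
os.environ.setdefault("MASTER_ADDR", "node1")   # placeholder address
os.environ.setdefault("MASTER_PORT", "29500")   # placeholder port

dist.init_process_group(
    backend="nccl",
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
)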

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchrunx-0.2.0.tar.gz (47.8 kB view details)

Uploaded Source

Built Distribution

torchrunx-0.2.0-py3-none-any.whl (17.5 kB view details)

Uploaded Python 3

File details

Details for the file torchrunx-0.2.0.tar.gz.

File metadata

  • Download URL: torchrunx-0.2.0.tar.gz
  • Upload date:
  • Size: 47.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.0

File hashes

Hashes for torchrunx-0.2.0.tar.gz

  • SHA256: 9838e74a2c6eea8b5446428c58baf804a9f8128558f9969b53a637948fd2666a
  • MD5: a4939f5d124edb54d6cfccf230612266
  • BLAKE2b-256: aec2843c9430f9abf805a9136179e890f1a65196b730a1a6db65588335250f99


File details

Details for the file torchrunx-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: torchrunx-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 17.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.0

File hashes

Hashes for torchrunx-0.2.0-py3-none-any.whl

  • SHA256: 69fbf0fc600868c6eb778c6cfbdbc90231a11f95a01a0d8fbefb0d30ee84dc7a
  • MD5: e90c4a9cab7391adbef7df560aeeff4d
  • BLAKE2b-256: f84425413b79696d5d95864df86d3ed7d229b8578c27ed714d6384172bf757bc

