
Automatically initialize distributed PyTorch environments

Project description

torchrunx 🔥


By Apoorv Khandelwal and Peter Curtin

Automatically distribute PyTorch functions onto multiple machines or GPUs

Installation

pip install torchrunx

Requires: Linux (with shared filesystem & SSH access if using multiple machines)

Demo

Here's a simple example where we "train" a model on two nodes (with 2 GPUs each).

Training code
import os
import torch

def train():
    # torchrunx sets RANK, LOCAL_RANK, WORLD_SIZE, etc. for each worker
    # and initializes the default process group before calling this function
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    model = torch.nn.Linear(10, 10).to(local_rank)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(5, 10))
    labels = torch.randn(5, 10).to(local_rank)
    torch.nn.functional.mse_loss(outputs, labels).backward()
    optimizer.step()

    # Only rank 0 returns the trained model; the other workers return None.
    if rank == 0:
        return model

You could also use transformers.Trainer (or similar) to automatically handle all the multi-GPU / DDP code above.
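For instance, here is a rough, minimal sketch of an equivalent training function built on transformers.Trainer. The ToyDataset / ToyModel classes and all hyperparameters are placeholder stand-ins, and exact behavior depends on your transformers version; the point is only that Trainer picks up the distributed environment that torchrunx prepares, so no explicit DDP code is needed.

import os
import torch
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments

class ToyDataset(Dataset):
    # placeholder dataset: random inputs and targets
    def __len__(self):
        return 64
    def __getitem__(self, i):
        return {"x": torch.randn(10), "labels": torch.randn(10)}

class ToyModel(torch.nn.Module):
    # placeholder model that returns a loss, as Trainer expects
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
    def forward(self, x, labels=None):
        logits = self.linear(x)
        loss = torch.nn.functional.mse_loss(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}

def train():
    trainer = Trainer(
        model=ToyModel(),
        args=TrainingArguments(output_dir="output", per_device_train_batch_size=8),
        train_dataset=ToyDataset(),
    )
    trainer.train()
    if int(os.environ["RANK"]) == 0:
        return trainer.model

Either way, you launch your training function from a single driver script: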

import torch
import torchrunx as trx

if __name__ == "__main__":
    result = trx.launch(
        func=train,
        hostnames=["localhost", "other_node"],
        workers_per_host=2  # number of GPUs per node
    )

    trained_model = result.rank(0)
    torch.save(trained_model.state_dict(), "model.pth")
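
You run this driver with plain Python, e.g. python launch_train.py (the script name is just a placeholder); no torchrun invocation or per-machine SSH is needed.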

Full API

Advanced Usage

Why should I use this?

Whether you have 1 GPU, 8 GPUs, or 8 machines:

Features

  • Our launch() utility is super Pythonic
    • Return objects from your workers
    • Run python script.py instead of torchrun script.py
    • Launch multi-node functions, even from Python Notebooks
  • Fine-grained control over logging, environment variables, exception handling, etc.
  • Automatic integration with SLURM

Robustness

  • If you want to run a complex, modular workflow in one script
    • don't parallelize your entire script: just the functions you want (see the sketch below)!
    • no worries about memory leaks or OS failures
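
A minimal sketch of that pattern, reusing only the launch() call from the demo above (the preprocess / evaluate helpers are hypothetical placeholders for your own code):

import torchrunx as trx

def preprocess():
    ...  # ordinary single-process Python

def train():
    ...  # the only part that needs multiple GPUs / nodes

def evaluate(model):
    ...  # back to a single process

if __name__ == "__main__":
    preprocess()
    result = trx.launch(            # parallelize just this function
        func=train,
        hostnames=["localhost", "other_node"],
        workers_per_host=2,
    )
    evaluate(result.rank(0))        # continue in the launcher process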

Convenience

  • If you don't want to:
    • set up dist.init_process_group yourself (torchrunx does this before calling your function; see the sketch below)
    • manually SSH into every machine, run torchrun --master_addr ... --master_port ... on each, babysit failed processes, etc.
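
Because the process group is already initialized when your function runs, torch.distributed collectives work directly inside it. A minimal sketch, assuming CUDA devices are available on localhost:

import os
import torch
import torch.distributed as dist
import torchrunx as trx

def all_reduce_demo():
    local_rank = int(os.environ["LOCAL_RANK"])
    x = torch.ones(1).to(local_rank)
    dist.all_reduce(x, op=dist.ReduceOp.SUM)  # sum across all workers
    return x.item()

if __name__ == "__main__":
    result = trx.launch(func=all_reduce_demo, hostnames=["localhost"], workers_per_host=2)
    print(result.rank(0))  # 2.0 with two workers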



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchrunx-0.2.2.tar.gz (48.2 kB)

Uploaded Source

Built Distribution

torchrunx-0.2.2-py3-none-any.whl (17.9 kB)

Uploaded Python 3

File details

Details for the file torchrunx-0.2.2.tar.gz.

File metadata

  • Download URL: torchrunx-0.2.2.tar.gz
  • Upload date:
  • Size: 48.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.0

File hashes

Hashes for torchrunx-0.2.2.tar.gz

  • SHA256: d16d7c05e50633c7afec729782d7e7d875d5f1418dd14383df5f34516facd971
  • MD5: 6bf84754e216930f6691c8971b5150a2
  • BLAKE2b-256: 1c9814d118b3c3c62c19ac33646b2204e955573ea1069c98c113f149ebf28865


File details

Details for the file torchrunx-0.2.2-py3-none-any.whl.

File metadata

  • Download URL: torchrunx-0.2.2-py3-none-any.whl
  • Upload date:
  • Size: 17.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.0

File hashes

Hashes for torchrunx-0.2.2-py3-none-any.whl

  • SHA256: 2bce6f297eca135c31580cbf5586678c7b2147375a880873a4484eeb682f7248
  • MD5: 91d000af683b0608ad1e850718013d6a
  • BLAKE2b-256: 866840db980d9108f460b5bedd34e4b6140c8b9157b939368b667213093f42b6

