A package for distributed training & model parallelism using PyTorch
Project description
Overview
DeTrain is a Python package for training AI models with model parallelism. It currently focuses on two techniques: pipeline parallelism and tensor parallelism.
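The two techniques split a model differently: tensor parallelism shards individual layers across devices, while pipeline parallelism splits the model into sequential stages hosted by different workers. As a rough, framework-agnostic illustration of the tensor-parallel idea (plain PyTorch, not the DeTrain API; the class name and device list below are invented for the example), a linear layer can be sharded column-wise:

```python
import torch
import torch.nn as nn

# Conceptual sketch only: a Linear layer split column-wise across devices.
# Each shard computes a slice of the output; the slices are concatenated.
class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, devices):
        super().__init__()
        assert out_features % len(devices) == 0
        self.devices = devices
        self.shards = nn.ModuleList(
            nn.Linear(in_features, out_features // len(devices)).to(d)
            for d in devices
        )

    def forward(self, x):
        # Every shard sees the full input and produces part of the output.
        parts = [shard(x.to(d)) for shard, d in zip(self.shards, self.devices)]
        return torch.cat([p.to(self.devices[0]) for p in parts], dim=-1)

layer = ColumnParallelLinear(8, 4, devices=["cpu", "cpu"])
y = layer(torch.randn(2, 8))  # -> shape (2, 4)
```

Pipeline parallelism, which the Usage section below demonstrates with DeTrain's API, instead places whole stages of the network on different workers and streams micro-batches through them.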
Installation
You can install DeTrain using pip:
```
pip install detrain
```
Usage
Once installed, you can use DeTrain in your Python scripts like this:
```python
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim

from detrain.ppl.args_util import get_args
from detrain.ppl.worker import run_worker
from detrain.ppl.dataset_util import get_torchvision_dataset
from shards_model import NNShard1, NNShard2

if __name__ == "__main__":
    # Parse CLI arguments and read the distributed context
    # (WORLD_SIZE and RANK are expected in the environment).
    args = get_args()
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    epochs = int(args.epochs)
    batch_size = int(args.batch_size)
    lr = float(args.lr)

    # List the CUDA devices visible to this process.
    for i in range(torch.cuda.device_count()):
        print(torch.cuda.get_device_properties(i).name)

    devices = []
    workers = []
    shards = [NNShard1, NNShard2]

    # Map each position of the underscore-separated --gpu string to a
    # device: "1" means the corresponding process uses its local GPU
    # ("cuda:0"), anything else means CPU. Positions after the first
    # also register a named RPC worker.
    if args.gpu is not None:
        flags = args.gpu.split('_')
        for dv in range(len(flags)):
            if dv > 0:
                workers.append(f"worker{dv}")
            if int(flags[dv]) == 1:
                devices.append("cuda:0")
            else:
                devices.append("cpu")

    # Loss function and optimizer class (instantiated inside the workers).
    loss_fn = nn.CrossEntropyLoss()
    optimizer_class = optim.SGD

    # Train/test dataloaders for MNIST.
    (train_dataloader, test_dataloader) = get_torchvision_dataset("MNIST", batch_size)

    print(f"World_size: {world_size}, Rank: {rank}")

    tik = time.time()
    run_worker(
        rank,
        world_size,
        (args.split_size, workers, devices, shards),
        train_dataloader,
        test_dataloader,
        loss_fn,
        optimizer_class,
        epochs,
        batch_size,
        lr,
    )
    tok = time.time()
    print(f"number of splits = {args.split_size}, execution time = {tok - tik}")
```
For more detailed examples, please see the DeTrain examples in the project's GitHub repository.
Contributing
Contributions are welcome! If you’d like to contribute to DeTrain, please follow these steps:
1. Fork the repository on GitHub.
2. Create a new branch.
3. Make your changes and commit them with clear descriptions.
4. Push your changes to your fork.
5. Submit a pull request.
Bug Reports and Feedback
If you encounter any bugs or have feedback, please open an issue on the GitHub repository.
License
DeTrain is licensed under the MIT License. See the LICENSE file for more information.
File details
Details for the file detrain-0.2.6.tar.gz.
File metadata
- Download URL: detrain-0.2.6.tar.gz
- Upload date:
- Size: 15.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6c88ffc326013aeeaa675aee365af335be7ee8f50799e7eeb66acf13e8b95df2 |
| MD5 | 498c36dd6f7947d8cd237529abe526cb |
| BLAKE2b-256 | a79153324575d8608d11920d0b4e647c31ae16c51540af4387973771140dc81e |
File details
Details for the file detrain-0.2.6-py3-none-any.whl.
File metadata
- Download URL: detrain-0.2.6-py3-none-any.whl
- Upload date:
- Size: 16.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 651bd12301a8011746797eef64a7a3a9c838f0d44b622f5e7aec8cdb0415a40e |
| MD5 | 210412fc3e9591bf883d090759f75b4a |
| BLAKE2b-256 | 75c648a5fbef3722ceeb979b0c6991e3744540cbd94d4d3e05c1c0ea0737d1f1 |