
Project description

🚧 PipeGoose: Training any 🤗 transformers model in Megatron-LM 3D parallelism and ZeRO-1 out of the box


Honk honk honk! This project is actively under development. Check out my learning progress here.

⚠️ The project is actively under development and not ready for use.

⚠️ The APIs are still a work in progress and could change at any time. None of the public APIs are set in stone until we hit version 0.6.9.

⚠️ Support for hybrid 3D parallelism and a distributed optimizer for 🤗 transformers will be available in the coming weeks (the core implementation is essentially done, but it does not yet support 🤗 transformers).

⚠️ This library currently underperforms Megatron-LM and DeepSpeed (it does not yet reach competitive throughput). We are actively pushing it to reach 180 TFLOPs and go beyond Megatron-LM.

The snippet below shows the intended workflow: the lines marked with + are all you add to train a 🤗 transformers model with hybrid tensor and data parallelism.

from torch.utils.data import DataLoader
+ from torch.utils.data.distributed import DistributedSampler
from torch.optim import SGD
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

+ from pipegoose.distributed import ParallelContext, ParallelMode
+ from pipegoose.nn import DataParallel, TensorParallel

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token

BATCH_SIZE = 4
+ DATA_PARALLEL_SIZE = 2
+ parallel_context = ParallelContext.from_torch(
+    tensor_parallel_size=2,
+    data_parallel_size=2,
+    pipeline_parallel_size=1
+ )
+ model = TensorParallel(model, parallel_context).parallelize()
+ model = DataParallel(model, parallel_context).parallelize()
model.to("cuda")
+ device = next(model.parameters()).device

optim = SGD(model.parameters(), lr=1e-3)

dataset = load_dataset("imdb", split="train")
+ dp_rank = parallel_context.get_local_rank(ParallelMode.DATA)
+ sampler = DistributedSampler(dataset, num_replicas=DATA_PARALLEL_SIZE, rank=dp_rank, seed=42)
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE // DATA_PARALLEL_SIZE, shuffle=False, sampler=sampler)

for epoch in range(100):
+    sampler.set_epoch(epoch)

    for batch in dataloader:
        inputs = tokenizer(batch["text"], padding=True, truncation=True, max_length=1024, return_tensors="pt")
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        labels = inputs["input_ids"]

        outputs = model(**inputs, labels=labels)

        optim.zero_grad()
        outputs.loss.backward()
        optim.step()

Install and try it out

You can install the package through the following command:

pip install pipegoose

Then try out a hybrid tensor and data parallelism training script. Note that --nproc-per-node must match the total number of workers, i.e. tensor_parallel_size × data_parallel_size × pipeline_parallel_size (2 × 2 × 1 = 4 for the snippet above).

cd pipegoose/examples
torchrun --standalone --nnodes=1 --nproc-per-node=4 hybrid_parallelism.py

We ran a small-scale correctness test by comparing the training losses of a parallelized 🤗 transformers model against an unmodified baseline, starting from an identical checkpoint and identical training data (a minimal version of this check is sketched after the list below). Rigorous large-scale convergence and weak-scaling benchmarks against Megatron-LM and DeepSpeed will follow in the near future.

  • Data Parallelism [link]
  • Tensor Parallelism
  • Hybrid 2D Parallelism
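A minimal sketch of that loss check, assuming a parallelized model and a single-process baseline loaded from the same checkpoint and fed the same batches; the helper below is illustrative, not part of the pipegoose API:

import torch

def training_losses_match(parallelized_model, baseline_model, batches, atol=1e-3):
    # Both models are assumed to start from the same checkpoint and to see the
    # same batches; each batch is a dict of tensors already on the right device.
    for batch in batches:
        parallel_loss = parallelized_model(**batch, labels=batch["input_ids"]).loss
        baseline_loss = baseline_model(**batch, labels=batch["input_ids"]).loss
        if not torch.allclose(parallel_loss.detach().cpu(), baseline_loss.detach().cpu(), atol=atol):
            return False
    return True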

Features

  • Megatron-style 3D parallelism (a speculative sketch of how these features compose follows this list)
  • ZeRO-1: Distributed BF16 Optimizer
  • Highly optimized CUDA kernels ported from Megatron-LM and DeepSpeed
  • ...
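As a rough picture of how these features are meant to compose, here is a speculative extension of the earlier snippet (reusing its model). PipelineParallel and DistributedOptimizer, and their signatures, are assumptions extrapolated from the TensorParallel/DataParallel pattern above, not a documented API:

from torch.optim import SGD
from pipegoose.distributed import ParallelContext
from pipegoose.nn import DataParallel, TensorParallel

parallel_context = ParallelContext.from_torch(
    tensor_parallel_size=2,
    data_parallel_size=2,
    pipeline_parallel_size=1,  # bump this once pipeline parallelism is wired in
)
model = TensorParallel(model, parallel_context).parallelize()
# model = PipelineParallel(model, parallel_context).parallelize()  # assumed API, not yet public
model = DataParallel(model, parallel_context).parallelize()

optim = SGD(model.parameters(), lr=1e-3)
# optim = DistributedOptimizer(optim, parallel_context)  # assumed ZeRO-1 wrapper, not yet public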

Implementation Details

  • Supports training 🤗 transformers models with Megatron-style 3D parallelism and ZeRO-1, written from scratch.
  • Implements parallel compute and data transfer using separate CUDA streams (a generic sketch of the pattern follows this list).
  • Gradient checkpointing will be implemented by enforcing a virtual dependency in the backpropagation graph, ensuring that checkpointed activations are recomputed just in time for each (micro-batch, partition) pair (see the torch.utils.checkpoint example after this list).
  • Custom algorithms for model partitioning, with two default partitioning policies based on per-layer elapsed time and GPU memory consumption (a simplified greedy sketch follows this list).
  • Potential support includes:
    • Callbacks within the pipeline: Callback(function, microbatch_idx, partition_idx) for before and after the forward, backward, and recompute steps (for gradient checkpointing).
    • Mixed precision training.
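For the compute/transfer overlap mentioned in the second bullet, the underlying PyTorch pattern looks roughly like this; it is a generic illustration of the technique, not pipegoose's internal code:

import torch

copy_stream = torch.cuda.Stream()
compute_stream = torch.cuda.current_stream()

def prefetch(batch_cpu):
    # Issue host-to-device copies on a side stream so they can overlap with
    # compute running on the default stream (the CPU tensors should be in
    # pinned memory for the copies to be truly asynchronous).
    with torch.cuda.stream(copy_stream):
        return {name: t.to("cuda", non_blocking=True) for name, t in batch_cpu.items()}

def wait_for_prefetch(batch_gpu):
    # Block the compute stream until the copies finish, and tell the caching
    # allocator the tensors are now owned by the compute stream.
    compute_stream.wait_stream(copy_stream)
    for t in batch_gpu.values():
        t.record_stream(compute_stream)
    return batch_gpu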
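The just-in-time recomputation described in the gradient checkpointing bullet is the same idea exposed by torch.utils.checkpoint, shown generically below; how pipegoose enforces the virtual dependency per (micro-batch, partition) is not shown here:

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.GELU())
x = torch.randn(4, 128, requires_grad=True)

# Activations inside `layer` are discarded after the forward pass and
# recomputed during backward, just before their gradients are needed.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()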
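And a simplified illustration of cost-based partitioning: the greedy balancing below only sketches the idea of splitting layers by profiled per-layer cost (elapsed time or memory) and is hypothetical, not pipegoose's actual algorithm:

def partition_by_cost(layer_costs, num_partitions):
    # Greedily split layers into contiguous partitions of roughly equal total cost.
    # `layer_costs` is a list of per-layer costs; returns [start, end) index pairs.
    total = sum(layer_costs)
    target = total / num_partitions
    partitions, start, running = [], 0, 0.0
    for i, cost in enumerate(layer_costs):
        running += cost
        # Close a partition once its accumulated cost reaches the target,
        # as long as partitions and layers remain for the rest of the model.
        if running >= target and len(partitions) < num_partitions - 1 and i + 1 < len(layer_costs):
            partitions.append((start, i + 1))
            start, running = i + 1, 0.0
    partitions.append((start, len(layer_costs)))
    return partitions

# Example: 6 layers profiled by elapsed time, split across 2 pipeline stages.
print(partition_by_cost([1.0, 2.0, 1.5, 0.5, 2.0, 1.0], num_partitions=2))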

Appreciation

  • Big thanks to 🤗 Hugging Face for sponsoring this project with 8x A100 GPUs for testing, and to Zach Schrier for the monthly Twitch donations!

  • The library's APIs are inspired by OSLO's and ColossalAI's APIs.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pipegoose-0.2.0.tar.gz (44.2 kB view details)

Uploaded Source

Built Distribution

pipegoose-0.2.0-py3-none-any.whl (67.0 kB view details)

Uploaded Python 3

File details

Details for the file pipegoose-0.2.0.tar.gz.

File metadata

  • Download URL: pipegoose-0.2.0.tar.gz
  • Upload date:
  • Size: 44.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.1 CPython/3.9.13 Darwin/23.0.0

File hashes

Hashes for pipegoose-0.2.0.tar.gz
  • SHA256: c8d61d77bfe66ddd66573c69733b9617ebfce4d1d356d9e309b502f156764af3
  • MD5: 04a091d5372da4a8a0de948355f78a9f
  • BLAKE2b-256: 77e60d50659350606c26241466d3ad7ba01e4c1e620dfd2afc6f9b2ade3ef9e0

See more details on using hashes here.

File details

Details for the file pipegoose-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: pipegoose-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 67.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.1 CPython/3.9.13 Darwin/23.0.0

File hashes

Hashes for pipegoose-0.2.0-py3-none-any.whl
  • SHA256: f60414446cbbb78ece443e777e7e8382266377b72171bb96b50cbccde5cfba13
  • MD5: a95ba5864896e14ee69ff462a96bd4a9
  • BLAKE2b-256: 31f8db4ffb0d55cbf8a2e6e983aefcf9b4c05b7729579afd993f3a8f3b2cf228

See more details on using hashes here.
