GPipe for PyTorch
A GPipe implementation in PyTorch.
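from torch import nn
from torchgpipe import GPipe

# a, b, c, d are your own nn.Module layers.
model = nn.Sequential(a, b, c, d)
# Four partitions of one layer each; every mini-batch is split into 8 micro-batches.
model = GPipe(model, balance=[1, 1, 1, 1], chunks=8)

for input in data_loader:
    output = model(input)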
What is GPipe?
GPipe is a scalable pipeline parallelism library published by Google Brain that allows efficient training of large, memory-consuming models. According to the paper, GPipe can train a 25x larger model with 8x as many devices (TPU), and can train a model 3.5x faster with 4x as many devices.
GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism
Google trained AmoebaNet-B with 557M parameters using GPipe. This model achieved 84.3% top-1 and 97.0% top-5 accuracy on the ImageNet classification benchmark, the state of the art as of May 2019.
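To give a rough feel for why pipeline parallelism benefits from splitting a mini-batch into chunks, here is a minimal sketch of an idealized forward-pass schedule. It assumes every micro-batch takes one equal-length clock tick on every partition and ignores the backward pass and communication cost. The helper names pipeline_schedule and bubble_fraction are hypothetical, written for this illustration only; they are not part of the torchgpipe API and do not describe how torchgpipe schedules work internally.

```python
def pipeline_schedule(num_partitions, num_chunks):
    """Yield, for each clock tick, the (micro_batch, partition) pairs that run together.

    Idealized model: micro-batch i reaches partition j at tick i + j.
    """
    for tick in range(num_chunks + num_partitions - 1):
        yield [
            (tick - part, part)
            for part in range(num_partitions)
            if 0 <= tick - part < num_chunks
        ]


def bubble_fraction(num_partitions, num_chunks):
    """Fraction of partition/tick slots that sit idle in one forward pass."""
    total_ticks = num_chunks + num_partitions - 1
    return (num_partitions - 1) / total_ticks


# Same shape as the example above: 4 partitions, chunks=8.
schedule = list(pipeline_schedule(num_partitions=4, num_chunks=8))
for tick, tasks in enumerate(schedule):
    print(tick, tasks)
print(f"idle fraction: {bubble_fraction(4, 8):.3f}")
```

Under these assumptions, raising the number of chunks relative to the number of partitions shrinks the idle fraction, which is the intuition behind the chunks parameter.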
Links
Source Code: https://github.com/kakaobrain/torchgpipe
Documentation: https://torchgpipe.readthedocs.io/
Original Paper: https://arxiv.org/abs/1811.06965