GPipe for PyTorch
A GPipe implementation in PyTorch.
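from torch import nn
from torchgpipe import GPipe

# a, b, c, d stand for nn.Module layers; data_loader is any iterable of input batches.
model = nn.Sequential(a, b, c, d)
# balance: number of layers per partition; chunks: number of micro-batches per mini-batch.
model = GPipe(model, balance=[1, 1, 1, 1], chunks=8)

for input in data_loader:
    output = model(input)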
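To make the `balance` and `chunks` arguments above more concrete, here is a minimal pure-Python sketch of what they conceptually describe: `balance` groups consecutive layers into partitions, and `chunks` splits each mini-batch into micro-batches. The helpers `partition_layers` and `split_batch` are hypothetical names used only for illustration and are not part of torchgpipe. The library itself works on `nn.Sequential` modules and tensors, and its exact splitting rule may differ from the one assumed here.

```python
from math import ceil


def partition_layers(layers, balance):
    """Group consecutive layers into partitions.

    balance[i] is the number of layers placed in partition i
    (hypothetical helper, for illustration only).
    """
    if sum(balance) != len(layers):
        raise ValueError("balance must sum to the number of layers")
    partitions, start = [], 0
    for size in balance:
        partitions.append(layers[start:start + size])
        start += size
    return partitions


def split_batch(batch, chunks):
    """Split a mini-batch (a plain list here) into at most `chunks` micro-batches.

    Assumption: equal-sized slices, with a shorter final slice if needed.
    """
    if not batch:
        return []
    size = ceil(len(batch) / chunks)
    return [batch[i:i + size] for i in range(0, len(batch), size)]


layers = ["a", "b", "c", "d"]
partitions = partition_layers(layers, balance=[1, 1, 1, 1])
micro_batches = split_batch(list(range(16)), chunks=8)

print(partitions)          # one layer per partition
print(len(micro_batches))  # number of micro-batches
```

With `balance=[2, 2]` the same four layers would instead be grouped into two partitions of two layers each.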
What is GPipe?
GPipe is a scalable pipeline parallelism library published by Google Brain that enables efficient training of large, memory-consuming models. According to the paper, GPipe can train a 25x larger model using 8x as many devices (TPUs), and train a model 3.5x faster using 4x as many devices.
GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism
Google trained AmoebaNet-B with 557M parameters using GPipe. This model achieved 84.3% top-1 and 97.0% top-5 accuracy on the ImageNet classification benchmark (state-of-the-art performance as of May 2019).
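As a rough illustration of why splitting a mini-batch into micro-batches helps pipeline parallelism, the sketch below enumerates an idealized schedule in which partition `j` processes micro-batch `i` at clock tick `i + j`. The function names `clock_cycles` and `bubble_fraction` are hypothetical. The model covers only the forward pass, assumes every stage takes equal time, and should not be read as the library's actual scheduler.

```python
def clock_cycles(num_micro_batches, num_partitions):
    """Yield, for each tick, the (micro_batch, partition) pairs that can run in parallel."""
    for tick in range(num_micro_batches + num_partitions - 1):
        yield [
            (tick - j, j)
            for j in range(num_partitions)
            if 0 <= tick - j < num_micro_batches
        ]


def bubble_fraction(num_micro_batches, num_partitions):
    """Fraction of device time slots left idle in this idealized schedule."""
    ticks = num_micro_batches + num_partitions - 1
    return (num_partitions - 1) / ticks


# 4 micro-batches flowing through 2 partitions.
schedule = list(clock_cycles(4, 2))
for tick, tasks in enumerate(schedule):
    print(tick, tasks)

# More micro-batches shrink the idle fraction for a fixed number of partitions.
print(bubble_fraction(1, 4), bubble_fraction(8, 4))
```

In this simplified model, a single undivided batch on four partitions leaves three quarters of the device time idle, while eight micro-batches reduce the idle share to 3/11.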
Links
Source Code: https://github.com/kakaobrain/torchgpipe
Documentation: https://torchgpipe.readthedocs.io/
Original Paper: https://arxiv.org/abs/1811.06965
Hashes for torchgpipe-0.0.6-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 1b4f1747ed89a327f6741b8c5d177de1c07ea0c80735cc1d4862ff3d308feaf0
MD5 | 6fd2e3f0d16be7d38b2811d3c6bc5c4d
BLAKE2b-256 | 8ad8cccb6d987421fa0308720a059af79e7e0827a88c16e914df60720b5bb6db