Implementation of Google Research's "RigL" sparse model training method in PyTorch.
rigl-torch
Warning: This repository is still in active development, and results are not yet up to the RigL paper's spec. Coming soon!
An open-source PyTorch implementation of Google Research's paper "Rigging the Lottery: Making All Tickets Winners" (RigL), by Utku Evci (an AI Resident at Google Brain) et al., designed to be as versatile, simple, and fast as possible.
You only need to add 2 lines of code to your PyTorch project to use RigL to train your model with sparsity!
ImageNet Results
Results are not yet as complete as the original paper's; however, in the configuration below, this implementation's top-1 accuracy is 1 percentage point higher than the original's!
Architecture | Sparsity | Sparsity Distribution | Top-1 | Original Top-1 |
---|---|---|---|---|
ResNet50 | 90% | Uniform | 73% | 72% |
Other Implementations:
- The original implementation, written in TensorFlow, is available from Google Research.
- RigL has also been implemented in vanilla Python and for Graphcore.
Contributions Beyond the Paper:
Gradient Accumulation:
Motivation:
- The paper reports that its ImageNet experiments used a batch size of 4096, which isn't practical for everyone: fitting a batch that large of 224x224x3 images in VRAM takes 32 Tesla V100s.
- Consequently, if you train with a much smaller batch size (e.g. 1024 for ImageNet), RigL may perform suboptimally because the instantaneous gradient information is quite noisy. To remedy this, I have introduced a way of "emulating" larger batch sizes for topology modifications.
Method:
- In regular dense training, gradients are computed per batch: the per-sample losses are averaged and the derivative is taken with respect to the parameters. With a batch size of 1024, the gradient is therefore an average over 1024 data samples.
- In standard RigL, the grow/drop perturbations are scored on a single batch (every `delta` batches, typically `delta=100`), and this scoring replaces the optimizer step for that iteration. If the batch size is small, topology modifications are therefore based on very little data and can miss useful signal from the dataset. In dense training, batch size is a balancing act: batches that are too large give diminishing returns and can harm exploration, making it more likely to fall into a local minimum.
- If `gradient_accumulation_n` (the `grad_accumulation_n` argument of `RigLScheduler`) is > 1, then when RigL makes a topology modification it uses not only the current batch's gradients but also those of the previous `gradient_accumulation_n - 1` batches. It averages them element-wise and uses the resulting matrix to score the grow/drop perturbations.
- Note: `gradient_accumulation_n` must lie in the interval [1, `delta`). If `gradient_accumulation_n == 1`, nothing changes from the paper's spec. If `gradient_accumulation_n == delta - 1`, RigL scores based on every batch from the previous RigL step to the current one.
Results:
- Setting `gradient_accumulation_n` to a value > 1 increases ImageNet performance by about 0.3-1% when training with a batch size of 1024. To get the best results with a batch size of 1024 on ImageNet, you should also multiply `delta` by 4. With a batch size of 4096, each epoch has 4x fewer iterations (4096/1024 = 4), so you perform 4x fewer RigL steps than with a batch size of 1024; scaling `delta` by 4 keeps the number of samples seen between topology updates the same.
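The arithmetic behind the 4x rule can be checked with a short sketch. It assumes `T_end` is 75% of the total iterations (as recommended in the usage section), a hypothetical dataset of exactly 300 batches of 4096 samples, and it approximates the number of topology updates as `T_end // delta`. The helper names `scaled_delta` and `rigl_update_count` are illustrative and are not part of rigl-torch.

```python
def scaled_delta(delta_ref: int, batch_size_ref: int, batch_size: int) -> int:
    # Keep the number of samples seen between topology updates constant:
    # a k-times smaller batch needs a k-times larger delta.
    if batch_size_ref % batch_size != 0:
        raise ValueError("batch_size_ref must be a multiple of batch_size")
    return delta_ref * (batch_size_ref // batch_size)


def rigl_update_count(dataset_size: int, batch_size: int, epochs: int,
                      delta: int, t_end_fraction: float = 0.75) -> int:
    # Approximate number of topology updates: one every `delta` iterations
    # until T_end = t_end_fraction * total iterations.
    total_iterations = (dataset_size // batch_size) * epochs
    t_end = int(t_end_fraction * total_iterations)
    return t_end // delta


DATASET_SIZE = 4096 * 300  # hypothetical dataset size: exactly 300 batches of 4096
EPOCHS = 100

print(scaled_delta(100, 4096, 1024))                       # 400
print(rigl_update_count(DATASET_SIZE, 4096, EPOCHS, 100))  # 225
print(rigl_update_count(DATASET_SIZE, 1024, EPOCHS, 100))  # 900: 4x more updates
print(rigl_update_count(DATASET_SIZE, 1024, EPOCHS, 400))  # 225: matches batch size 4096
```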
User Setup:
pip install rigl-torch
Contributor Setup:
- Clone this repository:
git clone https://github.com/McCrearyD/rigl-torch
- Change into the repository directory:
cd rigl-torch
- Install dependencies:
pip install -r requirements.txt
- Install the package (`-e` installs it in editable mode, so your modifications take effect):
pip install -e .
Usage:
- Run the tests by doing `cd rigl-torch`, then `pytest`.
- I have provided some example training scripts that were slightly modified to add RigL's functionality. Each adds a few argument-parser statements and only the 2 required lines of RigL code. See them, with links to the originals, here:
  - ImageNet: RigL | Original | RigL + SageMaker
  - MNIST: RigL | Original
- Or, more impressively, you can use the pruning power of RigL by adding 2 lines of code to your existing training script! Here is how:
```python
import torch
from rigl_torch.RigL import RigLScheduler

# first, create your model
model = ...  # note: only tested on torch.hub's ResNet networks (e.g. resnet18 / resnet50)

# create your dataset/dataloader
dataset = ...
dataloader = ...

# define your loss function (e.g. torch.nn.CrossEntropyLoss()) and optimizer (recommended: SGD w/ momentum)
criterion = ...
optimizer = ...

# RigL runs best when RigL's topology modifications run for 75% of the total training iterations (batches),
# so calculate T_end accordingly
epochs = 100
total_iterations = len(dataloader) * epochs
T_end = int(0.75 * total_iterations)

# ------------------------------------ REQUIRED LINE # 1 ------------------------------------
# now, create the RigLScheduler object
pruner = RigLScheduler(model,                            # model you created
                       optimizer,                        # optimizer (recommended = SGD w/ momentum)
                       dense_allocation=0.1,             # float between 0 and 1: the fraction of weights kept
                                                         # (0.1 dense_allocation = 90% sparse)
                       sparsity_distribution='uniform',  # distribution hyperparam within the paper, currently only supports `uniform`
                       T_end=T_end,                      # T_end hyperparam within the paper (recommended = 75% * total_iterations)
                       delta=100,                        # delta hyperparam within the paper (recommended = 100)
                       alpha=0.3,                        # alpha hyperparam within the paper (recommended = 0.3)
                       grad_accumulation_n=1,            # new hyperparam contribution (not in the paper)
                                                         # for more information, see the `Contributions Beyond the Paper` section
                       static_topo=False,                # if True, the topology is frozen, i.e. RigL does not modify it
                                                         # (for debugging)
                       ignore_linear_layers=False,       # if True, linear layers in the network are kept fully dense
                       state_dict=None)                  # if your training script checkpoints, save `pruner.state_dict()`
                                                         # and, when resuming, pass the loaded `state_dict` here
# -------------------------------------------------------------------------------------------

# ... more code ...

for epoch in range(epochs):
    for inputs, targets in dataloader:  # assumes the dataloader yields (inputs, targets) pairs
        # forward pass, loss, and backward pass; RigL uses these gradients for topology modifications
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()

        # instead of calling optimizer.step(), wrap it as such:
        # ------------------------------------ REQUIRED LINE # 2 ------------------------------------
        if pruner():
        # -------------------------------------------------------------------------------------------
            # this block executes according to the given hyperparameter schedule;
            # in other words, optimizer.step() is not called after a RigL step
            optimizer.step()

    # it is also recommended that you checkpoint your training progress after every epoch;
    # with RigL training you should also save the pruner's state_dict
    torch.save({
        'model': model.state_dict(),
        'pruner': pruner.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, 'checkpoint.pth')

# at any time you can print the RigLScheduler object to see the sparsity distributions, number of training steps/RigL steps, etc.
print(pruner)

# save the model
torch.save(model.state_dict(), 'model.pth')
```
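To show what `delta`, `alpha`, and `T_end` control, here is a minimal pure-Python sketch of one RigL topology update on a single flattened layer. It follows the schedule and drop/grow criteria described in the RigL paper: a cosine-decayed update fraction, dropping active weights by smallest magnitude, and growing inactive connections by largest gradient magnitude, with new connections starting at zero. It is a hypothetical illustration, not rigl-torch's implementation. The functions `is_rigl_step`, `update_fraction`, and `rigl_update` are invented for this example, and `grads` stands for the dense (possibly accumulated) gradient.

```python
import math
from typing import List, Sequence, Tuple


def is_rigl_step(t: int, delta: int, t_end: int) -> bool:
    # The topology is modified every `delta` iterations until `T_end`.
    return t % delta == 0 and t < t_end


def update_fraction(t: int, alpha: float, t_end: int) -> float:
    # Cosine decay from the paper: alpha / 2 * (1 + cos(pi * t / T_end)).
    return alpha / 2 * (1 + math.cos(math.pi * t / t_end))


def rigl_update(weights: Sequence[float], mask: Sequence[bool],
                grads: Sequence[float], fraction: float) -> Tuple[List[float], List[bool]]:
    """Drop the k smallest-magnitude active weights, then grow the k inactive
    connections with the largest gradient magnitude (k = fraction * #active)."""
    active = [i for i, m in enumerate(mask) if m]
    k = int(fraction * len(active))
    new_mask = list(mask)
    new_weights = [w if m else 0.0 for w, m in zip(weights, mask)]

    # Drop: deactivate the k active connections with the smallest |w|.
    for i in sorted(active, key=lambda j: abs(weights[j]))[:k]:
        new_mask[i] = False
        new_weights[i] = 0.0

    # Grow: activate the k currently inactive connections (just-dropped ones
    # included) with the largest |grad|; new connections start at zero.
    candidates = [i for i, m in enumerate(new_mask) if not m]
    for i in sorted(candidates, key=lambda j: abs(grads[j]), reverse=True)[:k]:
        new_mask[i] = True
        new_weights[i] = 0.0
    return new_weights, new_mask


print([t for t in range(1, 1001) if is_rigl_step(t, delta=100, t_end=750)])
# [100, 200, 300, 400, 500, 600, 700]

weights = [0.9, -0.1, 0.0, 0.5, 0.0, 0.0]
mask = [True, True, False, True, False, False]
grads = [0.0, 0.0, 0.8, 0.0, -0.3, 0.1]
new_weights, new_mask = rigl_update(weights, mask, grads, fraction=0.5)
print(new_mask)  # [True, False, True, True, False, False]
```

Because exactly `k` connections are dropped and `k` are grown, the layer's sparsity stays constant. The update fraction shrinks from `alpha` toward 0 as `t` approaches `T_end`.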
File details
Details for the file rigl-torch-0.5.2.tar.gz.
File metadata
- Download URL: rigl-torch-0.5.2.tar.gz
- Upload date:
- Size: 12.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.24.0 setuptools/49.2.1 requests-toolbelt/0.9.1 tqdm/4.51.0 CPython/3.9.0
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 6ac1144c6dd9e04c8527cf0f7335478cabf5126cf1b41aab282ada6027f8697b |
MD5 | 567fafe47714d451644c5e929208b1ff |
BLAKE2b-256 | 77936c3363657dcf52860f5cfd40f1a94debd466836267d7c622d58cc5c14805 |
File details
Details for the file rigl_torch-0.5.2-py3-none-any.whl.
File metadata
- Download URL: rigl_torch-0.5.2-py3-none-any.whl
- Upload date:
- Size: 8.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.24.0 setuptools/49.2.1 requests-toolbelt/0.9.1 tqdm/4.51.0 CPython/3.9.0
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 41b91f378dc1933ff61357e989472c625e5353104d8b00b7a0c9ddd28b5c4bab |
MD5 | fd87b3289e58c76de296694318e4ed13 |
BLAKE2b-256 | 9e38f87e8892ade00228f9bb12b548dc74560a64f57c9a9b4eedac228d1f9fa6 |