
Implementation of Google Research's "RigL" sparse model training method in PyTorch.

Project description

rigl-torch


Warning: This repository is still in active development; results are not yet up to the RigL paper spec. Coming soon!

An open source PyTorch implementation of Google Research's paper Rigging the Lottery: Making All Tickets Winners (RigL), authored by Utku Evci (an AI Resident @ Google Brain), designed to be as versatile, simple, and fast as possible.

You only need to add 2 lines of code to your PyTorch project to use RigL to train your model with sparsity!

ImageNet Results

Results aren't yet as complete as the original paper's, but in the configuration below this implementation outperforms the original by 1 percentage point!

Architecture | Sparsity % | Sparsity Distribution | Top-1 | Original Top-1
ResNet50     | 90%        | Uniform               | 73%   | 72%

Other Implementations:

Contributions Beyond the Paper:

Gradient Accumulation:

Motivation:

  • The paper's ImageNet experiments use a batch size of 4096, which isn't practical for everyone: holding that many 224x224x3 images in VRAM requires roughly 32 Tesla V100s.
  • Consequently, if you train with a significantly smaller batch size (e.g. bs=1024 for ImageNet), RigL may perform suboptimally because the instantaneous gradient information is quite noisy. To remedy this, I have introduced a way to "emulate" larger batch sizes for topology modifications.

Method:

  • In regular dense training, gradients are computed per batch: the loss is averaged over the samples and differentiated w.r.t. the parameters, so with a batch size of 1024 the gradients are effectively averaged over 1024 data samples.
  • In standard RigL, the grow/drop perturbations are scored on a single batch (every delta batches, typically delta=100), and that scoring replaces the optimizer step for that iteration. If the batch size is small, topology modifications are therefore based on very little data and can miss signal from the dataset. (As in dense training, this is a balancing act: overly large batch sizes have diminishing returns and can harm exploration, making it more likely to fall into a local minimum.)
  • If gradient_accumulation_n is > 1, then when RigL makes a topology modification it takes not only the current batch's gradients but also the previous gradient_accumulation_n - 1 batches' gradients, averages them element-wise, and uses this averaged matrix to score the grow/drop perturbations (see the sketch after this list).
  • Note: gradient_accumulation_n must lie in the interval [1, delta). If gradient_accumulation_n == 1, nothing changes from the paper's spec. If gradient_accumulation_n == (delta - 1), RigL scores based on every batch between the previous RigL step and the current one.
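
Below is a minimal sketch of what this averaging looks like in a standard PyTorch training loop. It is illustrative only: the names score_buffer, accumulate, and grow_scores are hypothetical helpers, not the RigLScheduler internals.

import torch

grad_accumulation_n = 4   # must lie in [1, delta); 1 reproduces the paper's behaviour
score_buffer = {}         # parameter name -> running sum of recent gradients

def accumulate(model):
    # call after loss.backward() on each of the last `grad_accumulation_n` batches
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        if name not in score_buffer:
            score_buffer[name] = torch.zeros_like(param.grad)
        score_buffer[name] += param.grad.detach()

def grow_scores():
    # element-wise average over the accumulated batches; RigL uses the resulting
    # gradient magnitudes to score which connections to grow at a topology step
    averaged = {name: total / grad_accumulation_n for name, total in score_buffer.items()}
    score_buffer.clear()
    return {name: g.abs() for name, g in averaged.items()}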

Results:

  • Setting gradient_accumulation_n to a value > 1 increases ImageNet performance by about 0.3-1% when using a batch size of 1024. To get the best results with batch size 1024 (for ImageNet), you should also multiply the delta value by 4: with the same delta, a batch size of 1024 triggers 4x more RigL steps per epoch than the paper's batch size of 4096 (4096/1024 = 4), so scaling delta keeps the schedule comparable. A short sketch of these settings follows.
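
For concreteness, here is a small sketch of the batch-size arithmetic above. The variable names are illustrative; grad_accumulation_n corresponds to the constructor argument shown in the usage section below.

paper_batch_size = 4096
my_batch_size = 1024
scale = paper_batch_size // my_batch_size   # 4096 / 1024 = 4

delta = 100 * scale            # RigL steps happen 4x less often, matching the paper's schedule
grad_accumulation_n = scale    # score topology changes on 4 batches' worth of gradients
assert 1 <= grad_accumulation_n < delta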

User Setup:

  • pip install rigl-torch

Contributor Setup:

  • Clone this repository: git clone https://github.com/McCrearyD/rigl-torch
  • Cd into repo: cd rigl-torch
  • Install dependencies: pip install -r requirements.txt
  • Install package (-e allows for modifications): pip install -e .

Usage:

  • Run the tests by doing cd rigl-torch, then pytest.

  • I have provided some example training scripts that were slightly modified to add RigL's functionality. Each adds a few argument-parser statements and only the 2 required lines of RigL code! See them, with links to the originals, here:

  • OR, more impressively, you can use the pruning power of RigL by adding 2 lines of code to your existing training script! Here is how:

import torch
from rigl_torch.RigL import RigLScheduler

# first, create your model
model = ... # note: only tested on torch.hub's resnet networks (i.e. resnet18 / resnet50)

# create your dataset/dataloader
dataset = ...
dataloader = ...

# define your optimizer (recommended SGD w/ momentum)
optimizer = ...


# RigL runs best when you allow RigL's topology modifications to run for 75% of the total training iterations (batches)
# so, let's calculate T_end according to this
epochs = 100
total_iterations = len(dataloader) * epochs
T_end = int(0.75 * total_iterations)

# ------------------------------------ REQUIRED LINE # 1 ------------------------------------
# now, create the RigLScheduler object
pruner = RigLScheduler(model,                           # model you created
                       optimizer,                       # optimizer (recommended = SGD w/ momentum)
                       dense_allocation=0.1,            # a float between 0 and 1 that designates how sparse you want the network to be 
                                                          # (0.1 dense_allocation = 90% sparse)
                       sparsity_distribution='uniform', # distribution hyperparam within the paper, currently only supports `uniform`
                       T_end=T_end,                     # T_end hyperparam within the paper (recommended = 75% * total_iterations)
                       delta=100,                       # delta hyperparam within the paper (recommended = 100)
                       alpha=0.3,                       # alpha hyperparam within the paper (recommended = 0.3)
                       grad_accumulation_n=1,           # new hyperparam contribution (not in the paper) 
                                                          # for more information, see the `Contributions Beyond the Paper` section
                       static_topo=False,               # if True, the topology will be frozen; in other words, RigL will not do its job
                                                          # (for debugging)
                       ignore_linear_layers=False,      # if True, linear layers in the network will be kept fully dense
                       state_dict=None)                 # if you have checkpointing enabled for your training script, you should save 
                                                          # `pruner.state_dict()` and when resuming pass the loaded `state_dict` into 
                                                          # the pruner constructor
# -------------------------------------------------------------------------------------------

# ... more code ...

for epoch in range(epochs):
    for data in dataloader:
        # do forward pass, calculate loss, etc.
        ...

        # instead of calling optimizer.step(), wrap it as such:

# ------------------------------------ REQUIRED LINE # 2 ------------------------------------
        if pruner():
# -------------------------------------------------------------------------------------------
            # this block executes according to the given hyperparameter schedule:
            # optimizer.step() is not called on iterations where RigL modifies the topology
            optimizer.step()

    # it is also recommended that after every epoch you checkpoint your training progress
    # to do so with RigL training you should also save the pruner object state_dict
    torch.save({
        'model': model.state_dict(),
        'pruner': pruner.state_dict(),
        'optimizer': optimizer.state_dict()
    }, 'checkpoint.pth')

# at any time you can print the RigLScheduler object and it will show you the sparsity distributions, number of training steps/rigl steps, etc!
print(pruner)

# save model
torch.save(model.state_dict(), 'model.pth')
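
To resume from the checkpoint saved above, pass the saved pruner state back into the constructor, as the comments describe. The sketch below continues the example above and assumes the same model, optimizer, and hyperparameters as the run that produced checkpoint.pth:

# resuming (sketch): rebuild model/optimizer exactly as before, then restore their state
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

# passing the saved state_dict restores the pruner's internal state,
# as described in the constructor comments above
pruner = RigLScheduler(model,
                       optimizer,
                       dense_allocation=0.1,
                       sparsity_distribution='uniform',
                       T_end=T_end,
                       delta=100,
                       alpha=0.3,
                       grad_accumulation_n=1,
                       state_dict=checkpoint['pruner'])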



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

rigl-torch-0.5.2.tar.gz (12.0 kB)

Uploaded Source

Built Distribution

rigl_torch-0.5.2-py3-none-any.whl (8.9 kB)

Uploaded Python 3

File details

Details for the file rigl-torch-0.5.2.tar.gz.

File metadata

  • Download URL: rigl-torch-0.5.2.tar.gz
  • Upload date:
  • Size: 12.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.24.0 setuptools/49.2.1 requests-toolbelt/0.9.1 tqdm/4.51.0 CPython/3.9.0

File hashes

Hashes for rigl-torch-0.5.2.tar.gz
Algorithm Hash digest
SHA256 6ac1144c6dd9e04c8527cf0f7335478cabf5126cf1b41aab282ada6027f8697b
MD5 567fafe47714d451644c5e929208b1ff
BLAKE2b-256 77936c3363657dcf52860f5cfd40f1a94debd466836267d7c622d58cc5c14805

See more details on using hashes here.

File details

Details for the file rigl_torch-0.5.2-py3-none-any.whl.

File metadata

  • Download URL: rigl_torch-0.5.2-py3-none-any.whl
  • Upload date:
  • Size: 8.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.24.0 setuptools/49.2.1 requests-toolbelt/0.9.1 tqdm/4.51.0 CPython/3.9.0

File hashes

Hashes for rigl_torch-0.5.2-py3-none-any.whl
Algorithm Hash digest
SHA256 41b91f378dc1933ff61357e989472c625e5353104d8b00b7a0c9ddd28b5c4bab
MD5 fd87b3289e58c76de296694318e4ed13
BLAKE2b-256 9e38f87e8892ade00228f9bb12b548dc74560a64f57c9a9b4eedac228d1f9fa6

See more details on using hashes here.
