Easy-to-use optimizers with adaptive gradient clipping, written in PyTorch.

Project description

AGC Optimizers

A small library for using adaptive gradient clipping (AGC) in your optimizer. Currently PyTorch only.

News

Sep 14, 2021

  • Add AdamW, Adam, SGD, and RMSprop with AGC
  • Add a first comparison between optimizers with and without AGC based on CIFAR10

Introduction

In 2021, Brock et al. introduced adaptive gradient clipping (AGC), a new clipping technique designed to stabilize large-batch training and high learning rates in their Normalizer-Free Networks (NFNets). Since this clipping method is not implemented in the leading frameworks, this library provides optimizers that support AGC.
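For intuition, AGC rescales each gradient so that, unit-wise, its norm never exceeds a fixed fraction (the clipping parameter) of the corresponding parameter's norm. The snippet below is a minimal sketch of that rule following the NFNet paper, not the internal code of this library; unitwise_norm and agc_clip_ are illustrative names.

import torch

def unitwise_norm(x):
    # Norm per output unit (per row of a linear weight, per filter of a conv),
    # as in the NFNet paper; whole-tensor norm for scalars and 1-D tensors.
    if x.ndim <= 1:
        return x.norm()
    return torch.linalg.vector_norm(x, dim=tuple(range(1, x.ndim)), keepdim=True)

def agc_clip_(param, clipping=1e-2, eps=1e-3):
    # In-place AGC: enforce unit-wise ||grad|| <= clipping * max(||param||, eps).
    if param.grad is None:
        return
    w_norm = unitwise_norm(param.detach()).clamp(min=eps)
    g_norm = unitwise_norm(param.grad.detach())
    max_norm = clipping * w_norm
    # Rescale only the units whose gradient norm exceeds the threshold.
    scale = torch.where(g_norm > max_norm,
                        max_norm / g_norm.clamp(min=1e-6),
                        torch.ones_like(g_norm))
    param.grad.detach().mul_(scale)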

Installation

pip install agc_optims

Usage

To stay consistent with PyTorch, all optimizer arguments remain the same as in the standard implementations. Only two parameters are added for AGC:

  • clipping : Hyperparameter for the clipping threshold. Default value 1e-2; smaller batch sizes demand a larger clipping parameter.
  • agc_eps : Epsilon term used in AGC to prevent gradients from being clipped to zero. Default value 1e-3.

SGD

from agc_optims.optim import SGD_AGC

net = Net() # your model

optimizer = SGD_AGC(net.parameters(), lr=0.01, momentum=0.9, clipping=0.16)

Adam

from agc_optims.optim import Adam_AGC

net = Net() # your model

optimizer = Adam_AGC(net.parameters(), lr=0.001, weight_decay=1e-4, clipping=0.16)

AdamW

from agc_optims.optim import AdamW_AGC

net = Net() # your model

optimizer = AdamW_AGC(net.parameters(), lr=0.001, weight_decay=1e-4, clipping=0.16)

RMSprop

from agc_optims.optim import RMSprop_AGC

net = Net() # your model

optimizer = RMSprop_AGC(net.parameters(), lr=0.001, clipping=0.16)

You can now use these optimizers just like their non-AGC counterparts.
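For example, a standard PyTorch training loop needs no changes; Net and train_loader below are placeholders for your own model and DataLoader, and the loss is just an example:

import torch.nn.functional as F
from agc_optims.optim import Adam_AGC

net = Net()  # your model
optimizer = Adam_AGC(net.parameters(), lr=0.001, weight_decay=1e-4, clipping=0.16)

for inputs, targets in train_loader:  # your DataLoader
    optimizer.zero_grad()
    loss = F.cross_entropy(net(inputs), targets)
    loss.backward()
    optimizer.step()  # performs the AGC-enabled update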

Comparison

The following comparison shows that for batch sizes 64 and 128, Adam with AGC performs better than standard Adam. SGD is unfortunately worse with AGC, but the batch sizes are also very small compared to those in the NFNet paper; this calls for more comparisons with larger batch sizes and on other datasets. RMSprop with AGC is better than without at both batch sizes. The learning rate was left at its default value for all optimizers, and the scripts in the performance_tests folder were used as the test environment.

(Figures: accuracy and loss on CIFAR10 for SGD, Adam, and RMSprop, with and without AGC, at batch sizes 64 and 128.)

As a little treat, I have also compared the speed of the optimizers with and without AGC to see whether AGC noticeably increases training time.

(Figure: training-time comparison with and without AGC at batch size 128.)
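If you want to measure the overhead on your own setup, a minimal sketch could look like the following; this is not the script used for the plots above, and Net and train_loader are placeholders:

import time
import torch
import torch.nn.functional as F
from agc_optims.optim import RMSprop_AGC

net = Net()  # your model
optimizer = RMSprop_AGC(net.parameters(), lr=0.001, clipping=0.16)

step_time = 0.0
for inputs, targets in train_loader:  # your DataLoader
    optimizer.zero_grad()
    F.cross_entropy(net(inputs), targets).backward()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # finish the backward pass before starting the clock
    start = time.perf_counter()
    optimizer.step()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure the update has finished on the GPU
    step_time += time.perf_counter() - start
print(f"accumulated optimizer.step() time: {step_time:.2f} s")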

To Do

  • Add a first comparison based on CIFAR10 with a small CNN
  • Add comparisons with higher batch sizes (256, 512, 1024)
  • Add tests for each optimizer
  • Make clipping == 0 disable AGC
  • Add a comparison based on CIFAR100 with a small CNN
  • Add comparisons based on CIFAR10/100 with ResNet
  • Add a comparison with ImageNet (I currently do not have enough GPU power; if someone provides some tests I would be grateful)
  • Add all optimizers included in PyTorch
  • Support frameworks other than PyTorch

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

agc_optims-0.0.1.tar.gz (10.6 kB)

Uploaded Source

Built Distribution

agc_optims-0.0.1-py3-none-any.whl (14.4 kB)

Uploaded Python 3

File details

Details for the file agc_optims-0.0.1.tar.gz.

File metadata

  • Download URL: agc_optims-0.0.1.tar.gz
  • Upload date:
  • Size: 10.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.9.5

File hashes

Hashes for agc_optims-0.0.1.tar.gz
  • SHA256: bb12df954df76c9b049672f4d17fb301fb103b542c0c08db4d43d7c7c8a3e074
  • MD5: 634553e69540673180944d1d0b67d803
  • BLAKE2b-256: c47529f07d5661fd819d898bd748f3fd9883066564f9fa241cf4435b2e1a63ca

See more details on using hashes here.

File details

Details for the file agc_optims-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: agc_optims-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 14.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.9.5

File hashes

Hashes for agc_optims-0.0.1-py3-none-any.whl
  • SHA256: ce2bea663123c2c2f6fac0eaaf19e87b1f126766818a787f300e576deffd02ba
  • MD5: 070fabd83d54e2ec9e14dd49a0d78dd4
  • BLAKE2b-256: ffdad1feedda932ed80752b462a3539b04be40c313f923b6f96f5dbb325f0bc3

See more details on using hashes here.
