
Python package for RMSGD

Project description

RMSGD: Augmented SGD Optimizer

Official PyTorch implementation of the RMSGD optimizer from:

Exploiting Explainable Metrics for Augmented SGD
Mahdi S. Hosseini, Mathieu Tuli, Konstantinos N. Plataniotis
Accepted in IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR 2022)


We propose new explainability metrics that measure the redundant information in a network's layers, and we exploit this information to augment the Stochastic Gradient Descent (SGD) optimizer by adaptively adjusting the learning rate of each layer. We call the resulting optimizer RMSGD. RMSGD is fast, outperforms existing state-of-the-art (SOTA) optimizers, and generalizes well across experimental configurations.
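As a minimal illustration of only the per-layer part of this idea (not RMSGD's metrics or its actual update rule), the sketch below applies a plain SGD step in which each layer k has its own learning rate lr_k instead of one global rate. The layer names, the values, and the dict-of-lists representation are hypothetical; the real optimizer operates on PyTorch parameters.

```python
from typing import Dict, List


def layerwise_sgd_step(
    weights: Dict[str, List[float]],
    grads: Dict[str, List[float]],
    lrs: Dict[str, float],
) -> None:
    """Apply w_k <- w_k - lr_k * g_k in place, with one learning rate per layer k."""
    for name, w in weights.items():
        lr = lrs[name]  # this layer's own step size
        for i, g in enumerate(grads[name]):
            w[i] -= lr * g


# Hypothetical two-layer "network" stored as plain lists of weights.
weights = {"conv1": [1.0, 2.0], "fc": [0.5]}
grads = {"conv1": [0.1, 0.2], "fc": [1.0]}
# Hypothetical per-layer rates; in RMSGD they are adapted from the layer metrics.
lrs = {"conv1": 0.1, "fc": 0.01}

layerwise_sgd_step(weights, grads, lrs)
print(weights)
```

Because each layer reads its own entry in lrs, an optimizer can make one layer take larger steps than another without touching the global schedule.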

Contents

This repository and branch contain the standalone, pip-installable optimizer. Alternatively, you can copy the contents of src/rmsgd into your own repository and use the optimizer as is.

For all code related to the paper and for replicating its experiments, see the paper branch.

Installation

You can install rmsgd with pip install rmsgd, or from source:

git clone https://github.com/mahdihosseini/RMSGD.git
cd RMSGD
pip install .

Usage

RMSGD can be used like any other optimizer, with one additional step:

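from rmsgd import RMSGD
...  # build network, criterion (loss function), data_loader and num_epochs
optimizer = RMSGD(...)
...
for epoch in range(num_epochs):
    for inputs, targets in data_loader:
        optimizer.zero_grad()
        outputs = network(inputs)
        loss = criterion(outputs, targets)
        loss.backward()  # compute gradients before stepping
        optimizer.step()
    optimizer.epoch_step()  # once per epoch, after the last batch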

The additional step is a call to .epoch_step() at the end of every epoch, which updates the analysis of the network layers that RMSGD uses to adapt each layer's learning rate.
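To make the two call frequencies concrete, here is a toy, framework-free stand-in. It is hypothetical and is not the rmsgd package's implementation: step() updates the weights on every mini-batch using the current per-layer rates, while epoch_step() recomputes those rates once per epoch from a per-layer score. Both the score (mean absolute weight) and the blending rule with beta are placeholder assumptions, not the paper's redundancy metrics.

```python
from typing import Callable, Dict, List


class ToyEpochOptimizer:
    """Hypothetical stand-in for the RMSGD call pattern (not its real internals).

    step() runs every mini-batch; epoch_step() runs once per epoch and
    re-derives each layer's learning rate from a per-layer score.
    """

    def __init__(
        self,
        weights: Dict[str, List[float]],
        lr: float,
        score_fn: Callable[[List[float]], float],
        beta: float = 0.5,
    ) -> None:
        self.weights = weights
        self.lrs = {name: lr for name in weights}  # every layer starts at the same lr
        self.score_fn = score_fn  # hypothetical per-layer metric
        self.beta = beta          # hypothetical smoothing factor

    def step(self, grads: Dict[str, List[float]]) -> None:
        # Per-batch SGD update with each layer's current learning rate.
        for name, w in self.weights.items():
            for i, g in enumerate(grads[name]):
                w[i] -= self.lrs[name] * g

    def epoch_step(self) -> None:
        # Placeholder rule: blend the previous lr with the layer's score.
        for name, w in self.weights.items():
            score = self.score_fn(w)
            self.lrs[name] = self.beta * self.lrs[name] + (1 - self.beta) * score


def mean_abs(w: List[float]) -> float:
    # Hypothetical score; RMSGD instead uses the paper's redundancy metrics.
    return sum(abs(x) for x in w) / len(w)


weights = {"conv1": [1.0, -1.0], "fc": [0.2]}
opt = ToyEpochOptimizer(weights, lr=0.1, score_fn=mean_abs)
for epoch in range(2):
    for _ in range(3):  # zero gradients keep the weights fixed, so lr changes are easy to follow
        opt.step({"conv1": [0.0, 0.0], "fc": [0.0]})
    opt.epoch_step()
print(opt.lrs)
```

The learning rates stay constant within an epoch and change only when epoch_step() is called. One plausible reason for this split is cost: any layer analysis placed in epoch_step() runs once per epoch rather than on every batch.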

Citation

@InProceedings{hosseini2022rmsgd,
  author    = {Hosseini, Mahdi S. and Tuli, Mathieu and Plataniotis, Konstantinos N.},
  title     = {Exploiting Explainable Metrics for Augmented SGD},
  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2022},
}

License

This project is released under the MIT license. Please see the LICENSE file for more information.


Download files

Source Distribution

rmsgd-1.0.2.tar.gz (10.0 kB)

Uploaded Source

Built Distribution

rmsgd-1.0.2-py3-none-any.whl (9.0 kB)

Uploaded Python 3

File details

Details for the file rmsgd-1.0.2.tar.gz.

File metadata

  • Download URL: rmsgd-1.0.2.tar.gz
  • Size: 10.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/33.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.10.1 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9

File hashes

Hashes for rmsgd-1.0.2.tar.gz
Algorithm Hash digest
SHA256 eb78db693f6aa0bacd07c8ee09e5510ad54306e8299ca762c98c482ad3a54f83
MD5 326099d4c6ba2098dac35d45ce955423
BLAKE2b-256 fe8d9b2c1cc36d4787c70f7deff72d06d3eaf38f79947b5444f1f5bef7cd2696

File details

Details for the file rmsgd-1.0.2-py3-none-any.whl.

File metadata

  • Download URL: rmsgd-1.0.2-py3-none-any.whl
  • Size: 9.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/33.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.10.1 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9

File hashes

Hashes for rmsgd-1.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 54cdf95eeaf7ee299a879c0b58ce1440b63f7e48739a0830aa5e512dc0cc65bd
MD5 824c26b360c745f58a32a61b58f7b5ef
BLAKE2b-256 e02e308de5332c6f290b943b3b1ebf53b86176c425aa05e2d73fbd64fe68e664
