
Python package for RMSGD

Project description

RMSGD: Augmented SGD Optimizer

Official PyTorch implementation of the RMSGD optimizer from the paper "Exploiting Explainable Metrics for Augmented SGD".


We propose new explainability metrics that measure the redundant information in a network's layers, and we exploit this information to augment the Stochastic Gradient Descent (SGD) optimizer by adaptively adjusting the learning rate of each layer. We call this new optimizer RMSGD. RMSGD is fast, outperforms existing state-of-the-art (SOTA) methods, and generalizes well across experimental configurations.
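To make "adaptively adjusting the learning rate of each layer" concrete, the toy sketch below keeps one learning rate per layer and applies a plain SGD update with it. The per-layer `scores` and the rule `lr = base_lr * score` are hypothetical placeholders chosen for illustration; they are not the explainability metrics or the update rule that RMSGD actually uses.

```python
# Toy illustration of layer-wise learning rates in plain Python.
# ASSUMPTION: `scores` and `layer_lrs` are hypothetical stand-ins,
# not RMSGD's explainability metrics or its actual update rule.

def layer_lrs(base_lr, scores):
    """Hypothetical rule: one learning rate per layer, scaled by a score."""
    return [base_lr * s for s in scores]


def sgd_step(layers, grads, lrs):
    """Plain SGD update, w <- w - lr * g, using each layer's own learning rate."""
    return [
        [w - lr * g for w, g in zip(weights, layer_grads)]
        for weights, layer_grads, lr in zip(layers, grads, lrs)
    ]


layers = [[1.0, 2.0], [3.0]]  # two "layers" of weights
grads = [[0.5, 0.5], [1.0]]   # their gradients
lrs = layer_lrs(0.1, scores=[1.0, 0.5])
layers = sgd_step(layers, grads, lrs)
print(lrs, layers)
```

Plain SGD would use a single learning rate for every layer; the only change here is that each layer's step size can differ, which is the general idea RMSGD builds on.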

Contents

This repository and branch contain the standalone optimizer, which is pip-installable. Alternatively, you can copy the contents of src/rmsgd into your own repository and use the optimizer as is.

For all code related to our paper, and to replicate its experiments, see the paper branch.

Installation

You can install RMSGD with pip install rmsgd, or from source:

git clone https://github.com/mahdihosseini/RMSGD.git
cd RMSGD
pip install .
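After installing, a quick standard-library check shows which version of the package is visible to your interpreter. This is a minimal sketch that assumes Python 3.8+ (for importlib.metadata); the helper name installed_rmsgd_version is hypothetical, and it returns None when rmsgd is not installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_rmsgd_version():
    """Return the installed rmsgd version string, or None if it is not installed."""
    try:
        return version("rmsgd")
    except PackageNotFoundError:
        return None


print(installed_rmsgd_version())
```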

Usage

RMSGD can be used like any other PyTorch optimizer, with one additional step:

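from rmsgd import RMSGD
...
optimizer = RMSGD(...)
...
for epoch in range(num_epochs):            # num_epochs, criterion: your own training setup
    for inputs, targets in data_loader:
        optimizer.zero_grad()
        outputs = network(inputs)
        loss = criterion(outputs, targets)
        loss.backward()                    # compute gradients before stepping
        optimizer.step()                   # per-batch parameter update
    optimizer.epoch_step()                 # once per epoch, after the last batch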

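You must call .epoch_step() at the end of every epoch, after the last optimizer.step() of that epoch, so that RMSGD can update its analysis of the network layers, which it uses to adjust each layer's learning rate.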
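The required call pattern (step() once per mini-batch, epoch_step() once per epoch) can be checked without PyTorch. In this sketch, CountingOptimizer is a hypothetical stand-in for RMSGD that only records how often each method is called; the forward pass, loss, and backward pass are elided.

```python
class CountingOptimizer:
    """Hypothetical stand-in for RMSGD that only counts calls (no PyTorch needed)."""

    def __init__(self):
        self.steps = 0
        self.epoch_steps = 0

    def zero_grad(self):
        pass

    def step(self):
        self.steps += 1  # called once per mini-batch

    def epoch_step(self):
        self.epoch_steps += 1  # called once per epoch


def train(optimizer, data_loader, num_epochs):
    for _ in range(num_epochs):
        for _batch in data_loader:
            optimizer.zero_grad()
            # forward pass, loss computation and loss.backward() would go here
            optimizer.step()
        optimizer.epoch_step()  # after the last batch of each epoch


opt = CountingOptimizer()
train(opt, data_loader=[0, 1, 2, 3], num_epochs=3)
print(opt.steps, opt.epoch_steps)  # 12 3
```

With 4 batches and 3 epochs, step() runs 12 times while epoch_step() runs only 3 times; calling epoch_step() inside the batch loop instead would update the layer analysis far more often than intended.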

Citation

@Article{hosseini2022rmsgd,
  author  = {},
  title   = {Exploiting Explainable Metrics for Augmented SGD},
  journal = {},
  year    = {},
}

License

This project is released under the MIT license. Please see the LICENSE file for more information.

Download files


Source Distribution

rmsgd-1.0.1.tar.gz (9.6 kB)

Built Distribution

rmsgd-1.0.1-py3-none-any.whl (8.8 kB)

File details

Details for the file rmsgd-1.0.1.tar.gz.

File metadata

  • Download URL: rmsgd-1.0.1.tar.gz
  • Size: 9.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/33.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.10.1 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9

File hashes

Hashes for rmsgd-1.0.1.tar.gz
Algorithm Hash digest
SHA256 221d41eb8ce0659816c6a49438f7d981ad811a7cba9198bdc3c4b9c8dc05de60
MD5 e40cd4452def2dcec7431392835d1888
BLAKE2b-256 48f1852d7615badcc966fe720552cf00cce70e2afd83e22ef58dcc315e0e9137
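The SHA256 digests listed for each file let you check a download before installing it. A minimal sketch using Python's hashlib, assuming the archive was saved under its original file name; EXPECTED is the SHA256 digest listed above for rmsgd-1.0.1.tar.gz, and the helper names are illustrative.

```python
import hashlib

# SHA256 digest listed above for rmsgd-1.0.1.tar.gz
EXPECTED = "221d41eb8ce0659816c6a49438f7d981ad811a7cba9198bdc3c4b9c8dc05de60"


def sha256_of(path, chunk_size=1 << 16):
    """Compute the SHA256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED):
    """Return True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected.lower()
```

For example, verify("rmsgd-1.0.1.tar.gz") should return True for an intact download. pip can also enforce hashes at install time through its hash-checking mode (--require-hashes).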

File details

Details for the file rmsgd-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: rmsgd-1.0.1-py3-none-any.whl
  • Size: 8.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/33.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.10.1 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9

File hashes

Hashes for rmsgd-1.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 5fb20c4b0bdc9dc2628e946a4c88c316753bb8ffd50156c89bf084aa3daf2390
MD5 842a9038343da50edbe5bf1f49b9f9d1
BLAKE2b-256 7135c9f24132d85006009407fd77c86df569653497d8a485d483e6ef22ba9d8b