
Python package for RMSGD


RMSGD: Augmented SGD Optimizer

Official PyTorch implementation of the RMSGD optimizer from the paper "Exploiting Explainable Metrics for Augmented SGD".


We propose new explainability metrics that measure the redundant information in each layer of a network, and we use this information to augment the Stochastic Gradient Descent (SGD) optimizer by adaptively adjusting the learning rate of each layer. We call this new optimizer RMSGD. RMSGD is fast, outperforms existing state-of-the-art (SOTA) optimizers, and generalizes well across experimental configurations.

Contents

This repository and branch contain the standalone optimizer, which can be installed with pip. Alternatively, you can copy the contents of src/rmsgd into your own repository and use the optimizer as is.

For all code related to our paper, including what is needed to replicate its experiments, see the paper branch.

Installation

You can install rmsgd from PyPI with pip install rmsgd, or from source:

git clone https://github.com/mahdihosseini/RMSGD.git
cd RMSGD
pip install .
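After installing, you may want to confirm which version of the package is present. The sketch below is a minimal check using only the standard library's importlib.metadata; the helper name installed_version is hypothetical and is not part of rmsgd.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str = "rmsgd") -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed.

    Hypothetical helper for illustration; it is not shipped with rmsgd.
    """
    try:
        # Reads the version from the installed distribution's metadata.
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_version()
    print(f"rmsgd {version}" if version else "rmsgd is not installed")
```

Looking up the distribution metadata avoids importing the package (and therefore PyTorch) just to check that the install succeeded.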

Usage

RMSGD can be used like any other optimizer, with one additional step:

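from rmsgd import RMSGD
...  # define network, criterion, data_loader and num_epochs as usual
optimizer = RMSGD(...)
...
for epoch in range(num_epochs):
    for inputs, targets in data_loader:
        optimizer.zero_grad()
        output = network(inputs)
        loss = criterion(output, targets)
        loss.backward()  # compute gradients before stepping
        optimizer.step()
    # once per epoch, after all batches: update the layer analysis
    optimizer.epoch_step()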

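You must call optimizer.epoch_step() once at the end of each epoch, after all batches have been processed. This updates the optimizer's analysis of the network layers, which RMSGD uses to adapt the learning rate of each layer.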
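To make the required call order concrete, here is a small sketch that swaps RMSGD for a hypothetical RecordingOptimizer stand-in, which only logs which methods are called. It is not RMSGD and performs no optimization; the forward pass and loss.backward() are left as a comment so the example runs without PyTorch.

```python
class RecordingOptimizer:
    """Hypothetical stand-in for RMSGD that only records method calls."""

    def __init__(self):
        self.calls = []

    def zero_grad(self):
        self.calls.append("zero_grad")

    def step(self):
        self.calls.append("step")

    def epoch_step(self):
        self.calls.append("epoch_step")


def train(optimizer, data_loader, num_epochs):
    for _ in range(num_epochs):
        for _batch in data_loader:
            optimizer.zero_grad()
            # forward pass, loss computation and loss.backward() go here
            optimizer.step()  # once per batch
        optimizer.epoch_step()  # once per epoch, after all batches


opt = RecordingOptimizer()
train(opt, data_loader=range(3), num_epochs=2)
print(opt.calls.count("step"), opt.calls.count("epoch_step"))  # 6 2
```

With 3 batches and 2 epochs, step() runs 6 times but epoch_step() runs only twice, each time after the last batch of an epoch.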

Citation

@Article{hosseini2022rmsgd,
  author  = {},
  title   = {Exploiting Explainable Metrics for Augmented SGD},
  journal = {},
  year    = {},
}

License

This project is released under the MIT license. Please see the LICENSE file for more information.



Download files

Download the file for your platform.

Source Distribution

rmsgd-1.0.0.tar.gz (9.9 kB)

Uploaded Source

Built Distribution

rmsgd-1.0.0-py3-none-any.whl (9.1 kB)

Uploaded Python 3

File details

Details for the file rmsgd-1.0.0.tar.gz.

File metadata

  • Download URL: rmsgd-1.0.0.tar.gz
  • Upload date:
  • Size: 9.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/33.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.10.1 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9

File hashes

Hashes for rmsgd-1.0.0.tar.gz
Algorithm     Hash digest
SHA256        2396023e6f1c702b340832495d0d6fdfd1033ebb56029be68c83a012c3423867
MD5           a6ee74b8a0c4d73c6a7579df11963bdc
BLAKE2b-256   324d2fed346c52027c9ec9b41949a558de4c14a7353cd486c04c06c5fbb49116


File details

Details for the file rmsgd-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: rmsgd-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 9.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/33.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.10.1 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9

File hashes

Hashes for rmsgd-1.0.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        5725adb08de2e84d18ded7c571247ee4ba4689f4a4d3ad8d105508d690d0d2ca
MD5           26f5b635f257ee31873c0a243c32f7b7
BLAKE2b-256   5201ecd76022e125d846417ccc6171344aa70e17176d324c70cb67d4bb185372
