Python package for RMSGD
RMSGD: Augmented SGD Optimizer
Official PyTorch implementation of the RMSGD optimizer from:
Exploiting Explainable Metrics for Augmented SGD
Mahdi S. Hosseini, Mathieu Tuli, Konstantinos N. Plataniotis
Accepted at the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR 2022)
We propose new explainability metrics that measure the redundant information in a network's layers, and we exploit this information to augment the Stochastic Gradient Descent (SGD) optimizer by adaptively adjusting the learning rate of each layer. We call this new optimizer RMSGD. RMSGD is fast, outperforms existing state-of-the-art (SOTA) optimizers, and generalizes well across experimental configurations.
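The paragraph above describes per-layer learning rates driven by a measure of redundancy, but this README does not define that measure. As a hedged illustration only, the sketch below uses stable rank (||W||_F^2 / ||W||_2^2, a standard low-rank quantity) as a stand-in metric on each layer's weight, unfolded to 2-D, and then maps it to a per-layer learning rate with a made-up scaling rule. The functions stable_rank, redundancy and per_layer_lr, and the rule itself, are hypothetical; the paper defines RMSGD's actual metrics and update.

```python
from typing import Dict

import numpy as np


def unfold_weight(weight: np.ndarray) -> np.ndarray:
    """Reshape a layer weight to 2-D as (out_features, everything else).

    A conv kernel of shape (out, in, kh, kw) becomes (out, in * kh * kw).
    """
    return weight.reshape(weight.shape[0], -1)


def stable_rank(weight: np.ndarray) -> float:
    """Stable rank ||W||_F^2 / ||W||_2^2; never larger than the true rank."""
    w = unfold_weight(weight)
    spectral = np.linalg.norm(w, 2)  # largest singular value
    if spectral == 0.0:
        return 0.0
    return float(np.linalg.norm(w, "fro") ** 2 / spectral ** 2)


def redundancy(weight: np.ndarray) -> float:
    """Hypothetical redundancy score in [0, 1]: 1 - stable_rank / max possible rank."""
    w = unfold_weight(weight)
    return 1.0 - stable_rank(weight) / min(w.shape)


def per_layer_lr(weights: Dict[str, np.ndarray], base_lr: float) -> Dict[str, float]:
    """Hypothetical rule (NOT RMSGD's): scale base_lr by (1 - redundancy) per layer."""
    return {name: base_lr * (1.0 - redundancy(w)) for name, w in weights.items()}


layers = {"conv1": np.ones((8, 3, 3, 3)), "fc": np.eye(4)}
print(per_layer_lr(layers, base_lr=0.1))
```

In PyTorch, per-layer learning rates of this kind are usually expressed as separate parameter groups, each with its own "lr" entry; how RMSGD organizes its per-layer rates internally is not described in this README.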
Contents
This repository and branch contain the standalone optimizer, which is pip-installable. Alternatively, you can copy the contents of src/rmsgd into your local repository and use the optimizer as is.
For all code relating to our paper, including the code to replicate its experiments, see the paper branch.
Installation
You can install rmsgd from PyPI:

    pip install rmsgd

or, equally, from source:

    git clone https://github.com/mahdihosseini/RMSGD.git
    cd RMSGD
    pip install .
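To confirm that the installation worked, a minimal sketch using only the standard library can check both that an rmsgd distribution is registered with pip and that the rmsgd module is importable. The helper names installed_rmsgd_version and rmsgd_importable are hypothetical, not part of the package.

```python
from importlib import metadata
from importlib.util import find_spec
from typing import Optional


def installed_rmsgd_version(dist_name: str = "rmsgd") -> Optional[str]:
    """Return the installed distribution version, or None if pip has not installed it."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def rmsgd_importable(module_name: str = "rmsgd") -> bool:
    """True if `import rmsgd` would find a module on the current sys.path."""
    return find_spec(module_name) is not None


print(installed_rmsgd_version(), rmsgd_importable())
```

The two checks are kept separate on purpose: if you copy src/rmsgd into your project instead of installing it with pip, the module is importable but no distribution version is recorded, so the first helper returns None while the second returns True.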
Usage
RMSGD can be used like any other optimizer, with one additional step:
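
    from rmsgd import RMSGD
    ...
    optimizer = RMSGD(...)
    ...
    # network, criterion, data_loader and num_epochs come from your own training setup
    for epoch in range(num_epochs):
        for inputs, targets in data_loader:
            optimizer.zero_grad()
            outputs = network(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        optimizer.epoch_step()  # once per epoch, after the last batch
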
In short, you must call .epoch_step() at the end of each epoch so that RMSGD updates its analysis of the network layers.
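The call order described above (zero_grad and step once per batch, epoch_step once per epoch) can be checked without PyTorch. The sketch below uses two hypothetical stand-in classes, RecordingOptimizer and PlainOptimizer, that only record calls; neither is RMSGD, and the forward and backward passes are left as a comment.

```python
from typing import Iterable, List


class RecordingOptimizer:
    """Hypothetical stand-in that only records method calls (it is NOT RMSGD)."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def zero_grad(self) -> None:
        self.calls.append("zero_grad")

    def step(self) -> None:
        self.calls.append("step")

    def epoch_step(self) -> None:
        self.calls.append("epoch_step")


class PlainOptimizer:
    """Hypothetical optimizer without an epoch-level hook, like plain SGD."""

    def __init__(self) -> None:
        self.steps = 0

    def zero_grad(self) -> None:
        pass

    def step(self) -> None:
        self.steps += 1


def train(optimizer, data_loader: Iterable, num_epochs: int) -> None:
    """Training-loop skeleton: per-batch step, then one epoch_step per epoch."""
    for _ in range(num_epochs):
        for _batch in data_loader:
            optimizer.zero_grad()
            # forward pass, loss computation and loss.backward() would go here
            optimizer.step()
        # Call the epoch-level hook only if the optimizer defines one,
        # so the same loop also works for optimizers such as plain SGD.
        epoch_step = getattr(optimizer, "epoch_step", None)
        if callable(epoch_step):
            epoch_step()


recorder = RecordingOptimizer()
train(recorder, data_loader=[0, 1, 2], num_epochs=2)
print(recorder.calls.count("step"), recorder.calls.count("epoch_step"))  # 6 2
```

The getattr guard is a design choice for loops that should switch between RMSGD and a standard optimizer without code changes; if you only ever use RMSGD, calling optimizer.epoch_step() directly is simpler.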
Citation
@Article{hosseini2022rmsgd,
    author = {Hosseini, Mahdi S. and Tuli, Mathieu and Plataniotis, Konstantinos N.},
    title = {Exploiting Explainable Metrics for Augmented SGD},
    journal = {Accepted in IEEE/CVF Conference on Computer Vision and Pattern Recognition},
    year = {2022},
}
License
This project is released under the MIT license. Please see the LICENSE file for more information.
Download files
File details
Details for the file rmsgd-1.0.2.tar.gz.
File metadata
- Download URL: rmsgd-1.0.2.tar.gz
- Upload date:
- Size: 10.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/33.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.10.1 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | eb78db693f6aa0bacd07c8ee09e5510ad54306e8299ca762c98c482ad3a54f83
MD5 | 326099d4c6ba2098dac35d45ce955423
BLAKE2b-256 | fe8d9b2c1cc36d4787c70f7deff72d06d3eaf38f79947b5444f1f5bef7cd2696
File details
Details for the file rmsgd-1.0.2-py3-none-any.whl.
File metadata
- Download URL: rmsgd-1.0.2-py3-none-any.whl
- Upload date:
- Size: 9.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/33.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.10.1 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | 54cdf95eeaf7ee299a879c0b58ce1440b63f7e48739a0830aa5e512dc0cc65bd
MD5 | 824c26b360c745f58a32a61b58f7b5ef
BLAKE2b-256 | e02e308de5332c6f290b943b3b1ebf53b86176c425aa05e2d73fbd64fe68e664
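To check a downloaded file against the SHA256 digests listed for the two distribution files, a minimal sketch using Python's standard hashlib module follows. The helper names sha256_hex, sha256_of_file and matches_published are illustrative and not part of the rmsgd package; the digests are copied from the SHA256 rows above.

```python
import hashlib
from typing import Iterable

# SHA256 digests copied from the file hashes listed on this page.
PUBLISHED_SHA256 = {
    "rmsgd-1.0.2.tar.gz": "eb78db693f6aa0bacd07c8ee09e5510ad54306e8299ca762c98c482ad3a54f83",
    "rmsgd-1.0.2-py3-none-any.whl": "54cdf95eeaf7ee299a879c0b58ce1440b63f7e48739a0830aa5e512dc0cc65bd",
}


def sha256_hex(chunks: Iterable[bytes]) -> str:
    """Hash a sequence of byte chunks and return the hex digest."""
    digest = hashlib.sha256()
    for chunk in chunks:
        digest.update(chunk)
    return digest.hexdigest()


def sha256_of_file(path: str, chunk_size: int = 1 << 16) -> str:
    """Hash a file in fixed-size chunks so it is never loaded into memory at once."""
    with open(path, "rb") as handle:
        return sha256_hex(iter(lambda: handle.read(chunk_size), b""))


def matches_published(path: str, filename: str) -> bool:
    """True if the local file's SHA256 equals the digest published for `filename`."""
    return sha256_of_file(path) == PUBLISHED_SHA256[filename]
```

For example, matches_published("rmsgd-1.0.2.tar.gz", "rmsgd-1.0.2.tar.gz") returns True for an intact download saved in the current directory. pip can also print the SHA256 digest of a local file with pip hash.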