ReinMax Algorithm

ReinMax

Beyond Straight-Through

ReinMax achieves second-order accuracy and is as fast as the original Straight-Through, which has first-order accuracy.

Straight-Through and How It Works

Straight-Through (as below) bridges discrete variables (y_hard) and back-propagation.

y_soft = theta.softmax(dim=-1)

# one_hot_multinomial is a non-differentiable function
y_hard = one_hot_multinomial(y_soft)

# with straight-through, the gradient of y_hard is computed
# as if y_soft had been used in the forward pass
y_hard = y_soft - y_soft.detach() + y_hard
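
As a rough, self-contained illustration of the snippet above, the following PyTorch sketch checks two things: the forward value of y_hard is still the sampled one-hot vector, and its gradient matches the gradient obtained by using y_soft directly. The body of one_hot_multinomial, the toy weights, and the tensor shapes are assumptions made for this example; the text names the helper but does not define it.

```python
import torch
import torch.nn.functional as F


def one_hot_multinomial(probs: torch.Tensor) -> torch.Tensor:
    """Sample one category per row and return it as a one-hot float tensor.

    Hypothetical helper: the README names this function but does not define it.
    """
    index = torch.multinomial(probs.detach(), num_samples=1).squeeze(-1)
    return F.one_hot(index, num_classes=probs.size(-1)).to(probs.dtype)


torch.manual_seed(0)
theta = torch.randn(4, 3, requires_grad=True)  # 4 rows, 3 categories (assumed)

y_soft = theta.softmax(dim=-1)
y_sample = one_hot_multinomial(y_soft)  # discrete, carries no gradient

# Forward value equals the sample; the backward pass flows through y_soft.
y_hard = y_soft - y_soft.detach() + y_sample

weights = torch.tensor([1.0, 2.0, 3.0])  # toy downstream loss (assumed)
loss = (y_hard * weights).sum()
loss.backward()

# Reference: the same loss evaluated on the soft probabilities only.
theta_ref = theta.detach().clone().requires_grad_(True)
(theta_ref.softmax(dim=-1) * weights).sum().backward()
```

Here theta.grad and theta_ref.grad should agree, which is what "act as if you had y_soft in the forward" means in practice.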

How Straight-Through works has been a long-standing mystery, leaving open questions such as whether we should use:

  • y_soft - y_soft.detach(),
  • (theta/tau).softmax() - (theta/tau).softmax().detach(),
  • or what?

Better Performance with Negligible Computation Overheads

We show that Straight-Through works as a special case of the forward Euler method, a numerical method with first-order accuracy. Inspired by Heun's method, a numerical method that achieves second-order accuracy without requiring the Hessian or other second-order derivatives, we propose ReinMax, which approximates the gradient with second-order accuracy at negligible computational overhead.
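
To make the difference between first-order and second-order accuracy concrete, here is a small standard-library sketch that compares forward Euler and Heun's method on a textbook ODE. It only illustrates the two numerical methods named above; it is not the ReinMax gradient estimator, and the test problem (dy/dt = y on [0, 1]) and step counts are assumptions chosen for the example.

```python
import math


def euler_step(f, t, y, h):
    """One forward Euler step: uses only the slope at the start of the interval."""
    return y + h * f(t, y)


def heun_step(f, t, y, h):
    """One Heun step: averages the start slope and a predicted end slope."""
    k1 = f(t, y)
    k2 = f(t + h, y + h * k1)  # slope at the Euler-predicted endpoint
    return y + 0.5 * h * (k1 + k2)


def integrate(step, f, y0, t_end, n_steps):
    """Apply `step` n_steps times from t = 0 to t = t_end."""
    h = t_end / n_steps
    t, y = 0.0, y0
    for _ in range(n_steps):
        y = step(f, t, y, h)
        t += h
    return y


def f(t, y):
    return y  # dy/dt = y with y(0) = 1, so the exact value at t = 1 is e


step_counts = (10, 20, 40)
errors = {
    name: [abs(integrate(step, f, 1.0, 1.0, n) - math.e) for n in step_counts]
    for name, step in (("euler", euler_step), ("heun", heun_step))
}

# Error ratio each time the step size is halved.
ratios = {
    name: [errs[i] / errs[i + 1] for i in range(len(errs) - 1)]
    for name, errs in errors.items()
}
```

Halving the step size roughly halves the Euler error and roughly quarters the Heun error, even though Heun uses only first-derivative evaluations.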

How to use?

install

pip install reinmax

enjoy

from reinmax import reinmax
...

def forward(self, ...):
    ...
-   y_soft = theta.softmax(dim=-1)
-   y_hard = one_hot_multinomial(y_soft)
-   y_hard = y_soft - y_soft.detach() + y_hard
+   y_hard, y_soft = reinmax(theta)
    ...

Examples

Citation

Please cite the following paper if you find our work useful. Thanks!

Liyuan Liu, Xiaodong Liu, Jianfeng Gao, Weizhu Chen, and Jiawei Han (2020). Understanding the Difficulty of Training Transformers. Proc. 2020 Conf. on Empirical Methods in Natural Language Processing (EMNLP'20).

@inproceedings{liu2020admin,
  title = {Understanding the Difficulty of Training Transformers},
  author = {Liu, Liyuan and Liu, Xiaodong and Gao, Jianfeng and Chen, Weizhu and Han, Jiawei},
  booktitle = {Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP 2020)},
  year = {2020}
}
