ReinMax Algorithm
Project description
ReinMax
Beyond Straight-Through
ReinMax achieves second-order accuracy and is as fast as the original Straight-Through, which has first-order accuracy.
Straight-Through and How It Works
Straight-Through (as below) bridges discrete variables (y_hard) and back-propagation.
y_soft = theta.softmax(dim=-1)
# one_hot_multinomial is a non-differentiable function
y_hard = one_hot_multinomial(y_soft)
# with straight-through, the derivative of y_hard will
# act as if you had `y_soft` in the forward
y_hard = y_soft - y_soft.detach() + y_hard
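To make the trick above concrete, here is a minimal pure-Python sketch of what Straight-Through computes for a single categorical variable. It is an illustration under stated assumptions, not the library's code: `one_hot_multinomial` is not defined in the snippet above, so a hypothetical stand-in is written here, the backward pass through softmax is done by hand instead of by autograd, and the downstream loss `f(y) = (c . y)^2` is invented for the example.

```python
import math
import random

def softmax(theta):
    # Numerically stable softmax over a list of logits.
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    z = sum(exps)
    return [e / z for e in exps]

def one_hot_multinomial(probs, rng):
    # Hypothetical stand-in: sample one category and return it as a one-hot list.
    idx = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return [1.0 if i == idx else 0.0 for i in range(len(probs))]

def straight_through_grad(theta, grad_f, rng):
    """Return (y_hard, Straight-Through estimate of d f / d theta)."""
    p = softmax(theta)
    y_hard = one_hot_multinomial(p, rng)
    # Forward uses the discrete sample, so df/dy is evaluated at y_hard.
    g = grad_f(y_hard)
    # Backward pretends the forward was y_soft = softmax(theta):
    # d p_i / d theta_j = p_i * (delta_ij - p_j)
    g_dot_p = sum(gi * pi for gi, pi in zip(g, p))
    grad_theta = [pj * (gj - g_dot_p) for pj, gj in zip(p, g)]
    return y_hard, grad_theta

# Assumed toy loss f(y) = (c . y)^2, hence df/dy = 2 * (c . y) * c.
c = [1.0, 2.0, 3.0]

def grad_f(y):
    s = sum(ci * yi for ci, yi in zip(c, y))
    return [2.0 * s * ci for ci in c]

rng = random.Random(0)
theta = [0.1, -0.3, 0.5]
y_hard, grad_theta = straight_through_grad(theta, grad_f, rng)
```

The forward value is the one-hot sample, while the gradient with respect to `theta` flows through the softmax Jacobian, which is the same effect `y_soft - y_soft.detach() + y_hard` has under automatic differentiation.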
How Straight-Through works has long been a mystery, which leaves open questions such as whether we should use:
- y_soft - y_soft.detach(),
- (theta/tau).softmax() - (theta/tau).softmax().detach(),
- or something else.
Better Performance with Negligible Computation Overheads
We show that Straight-Through works as a special case of the forward Euler method, a numerical method with first-order accuracy. Inspired by Heun's method, a numerical method that achieves second-order accuracy without requiring the Hessian or other second-order derivatives, we propose ReinMax, which approximates the gradient with second-order accuracy and negligible computational overhead.
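The difference between first-order and second-order accuracy can be seen on a textbook problem. The sketch below is not ReinMax itself; it only compares the two numerical methods named above, forward Euler and Heun's method, on the assumed test equation dy/dt = y with y(0) = 1, whose exact value at t = 1 is e. Halving the step size should roughly halve the Euler error and roughly quarter the Heun error.

```python
import math

def forward_euler(f, y0, t_end, n_steps):
    # First-order method: one slope evaluation per step.
    h = t_end / n_steps
    t, y = 0.0, y0
    for _ in range(n_steps):
        y += h * f(t, y)
        t += h
    return y

def heun(f, y0, t_end, n_steps):
    # Second-order method: average the slope at the start of the step
    # with the slope at the Euler-predicted end of the step.
    h = t_end / n_steps
    t, y = 0.0, y0
    for _ in range(n_steps):
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        y += 0.5 * h * (k1 + k2)
        t += h
    return y

def f(t, y):
    # Assumed test problem: dy/dt = y.
    return y

exact = math.e

def errors(n_steps):
    euler_err = abs(forward_euler(f, 1.0, 1.0, n_steps) - exact)
    heun_err = abs(heun(f, 1.0, 1.0, n_steps) - exact)
    return euler_err, heun_err

euler_coarse, heun_coarse = errors(50)
euler_fine, heun_fine = errors(100)

# Error reduction when the step size is halved.
euler_ratio = euler_coarse / euler_fine  # roughly 2 for a first-order method
heun_ratio = heun_coarse / heun_fine     # roughly 4 for a second-order method
print(f"Euler error ratio: {euler_ratio:.2f}, Heun error ratio: {heun_ratio:.2f}")
```

Heun's method gets the extra order from a second first-derivative evaluation only, which is why no Hessian is needed.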
How to use?
install
pip install reinmax
enjoy
from reinmax import reinmax
...
def forward(self, ...):
...
- y_soft = theta.softmax(dim=-1)
- y_hard = one_hot_multinomial(y_soft)
- y_hard = y_soft - y_soft.detach() + y_hard
+ y_hard, y_soft = reinmax(theta)
...
Examples
Citation
Please cite the following paper if you find our model useful. Thanks!
Liyuan Liu, Xiaodong Liu, Jianfeng Gao, Weizhu Chen, and Jiawei Han (2020). Understanding the Difficulty of Training Transformers. Proc. 2020 Conf. on Empirical Methods in Natural Language Processing (EMNLP'20).
@inproceedings{liu2020admin,
title={Understanding the Difficulty of Training Transformers},
author = {Liu, Liyuan and Liu, Xiaodong and Gao, Jianfeng and Chen, Weizhu and Han, Jiawei},
booktitle = {Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP 2020)},
year={2020}
}
Download files
Source Distribution
Built Distribution
File details
Details for the file reinmax-0.1.0.tar.gz.
File metadata
- Download URL: reinmax-0.1.0.tar.gz
- Upload date:
- Size: 3.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 69fb22f481e9e6c49ed6b912e7855b04772ce85d1f2ded23c0be863ab37d74fe
MD5 | 3eb5cd3642f3d79bc0309481d2038373
BLAKE2b-256 | 032f44f4dc51b96c499de697e12557e32b1051328fdc632c7cdc7649466803a8
File details
Details for the file reinmax-0.1.0-py2.py3-none-any.whl.
File metadata
- Download URL: reinmax-0.1.0-py2.py3-none-any.whl
- Upload date:
- Size: 3.8 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 2cf8cbb546989ba672c6e5ef2f25ffc84c8e3269dc750685e2cb7083dbc6866a
MD5 | cae97916a88ba2799f3d75ff08ecf2bd
BLAKE2b-256 | 8f766beafd121bfa58fb5b1776a14475765f589135fd214558d7b9f255854fbf