Adds the memory-efficient SM3 optimizer to PyTorch.
PyTorch-SM3
[source] [TensorFlow] [notebook]
Implements the SM3-II adaptive optimization algorithm for PyTorch. This algorithm was designed by Rohan Anil, Vineet Gupta, Tomer Koren, and Yoram Singer and implemented in TensorFlow.
The 'Square-root of Minima of Sums of Maxima of Squared-gradients Method' (SM3) algorithm is a memory-efficient adaptive optimization algorithm similar to Adam and Adagrad, with greatly reduced memory usage for history tensors. For an n x m matrix, Adam and Adagrad use O(nm) memory for history tensors, while SM3 uses O(n + m) due to the chosen cover. In general, a tensor of shape (n_1, n_2, ..., n_k) optimized using Adam will use O(n_1 * n_2 * ... * n_k) memory for storage tensors, while optimization using SM3 will use O(n_1 + n_2 + ... + n_k) memory. Despite storing fewer parameters, this optimization algorithm manages to be comparably effective.
This advantage shrinks drastically when momentum > 0, as the momentum is tracked using a tensor of the same shape as the tensor being optimized. With momentum, SM3 will use just over half as much memory as Adam, and a bit more than Adagrad.
If the gradient is sparse, the optimization algorithm will use O(n_1) memory, as there is only a row cover. The value of momentum is ignored in this case.
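To make these scalings concrete, here is a rough count of the optimizer-state entries for a single n x m weight matrix (the factor of two for Adam reflects its two full-size moment tensors; the exact numbers are only illustrative):

```python
# Illustrative count of optimizer-state entries for one n x m weight matrix.
n, m = 1024, 1024

adagrad_entries = n * m      # one full-size accumulator: O(nm)
adam_entries = 2 * n * m     # two full-size moment tensors: O(nm)
sm3_entries = n + m          # one accumulator per row plus one per column: O(n + m)

print(adagrad_entries, adam_entries, sm3_entries)  # 1048576 2097152 2048
```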
Installing
To install with pip, run `pip install torch-SM3`. Alternatively, clone the repository, run `python setup.py sdist`, and install the generated source package.
Usage
After installing, import the optimizer using `from SM3 import SM3`. The imported SM3 optimizer can be used in exactly the same way as any other PyTorch optimizer. For example, the optimizer can be constructed using `opt = SM3(model.parameters())`, with parameter updates applied using `opt.step()`.
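As a minimal sketch of a full training step (the model, loss, and data below are stand-ins, not part of the package):

```python
import torch
from SM3 import SM3

# Stand-in model and synthetic data, used only to show the call pattern.
model = torch.nn.Linear(10, 1)
loss_fn = torch.nn.MSELoss()
opt = SM3(model.parameters())

for _ in range(100):
    inputs = torch.randn(32, 10)
    targets = torch.randn(32, 1)

    opt.zero_grad()                         # clear gradients from the previous step
    loss = loss_fn(model(inputs), targets)
    loss.backward()                         # compute gradients
    opt.step()                              # apply the SM3 update
```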
Implementation Differences
The original authors mention that the optimization algorithm can be modified to use exponential moving averages, and I have incorporated this into the optimizer. If `beta = 0`, then the accumulated-gradient-squares method (i.e. the default SM3 method) is used. If `beta > 0`, then the updates use exponential moving averages instead. The authors found that `beta = 0` was superior for their experiments on translation and language models.
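Assuming `beta` is exposed as a constructor keyword argument (as the description above suggests), switching between the two modes would look roughly like this:

```python
import torch
from SM3 import SM3

model = torch.nn.Linear(10, 1)  # placeholder model

# Default SM3 behaviour: accumulate the squared gradients (beta = 0).
opt_accumulated = SM3(model.parameters(), beta=0.0)

# Use an exponential moving average of the squared gradients instead (beta > 0).
# The value 0.9 is only illustrative; the authors found beta = 0 superior in
# their translation and language-model experiments.
opt_ema = SM3(model.parameters(), beta=0.9)
```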
Requirements
The requirements given in `requirements.txt` are not the absolute minimum: the optimizer may function with versions of PyTorch earlier than 1.4, but those versions are not tested against. Furthermore, a change in the backend C++ signatures means that the current version of this package may not run against earlier versions of PyTorch.
Wisdom from authors
Their full advice can be seen in the sources above. Here are two points they emphasize and how to incorporate them.
Learning rate warm-up
They prefer using a learning rate that quadratically ramps up to the full learning rate. This is done in the notebook linked above by using the `LambdaLR` class. After creating the optimizer, you can use the following:

```python
lr_lambda = lambda epoch: min(1., (epoch / warm_up_epochs) ** 2)
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
```
The authors advocate for this as they found that the early gradients were typically very large in magnitude. By using a warm-up, the accumulated gradients are not dominated by the first few updates. After this warm-up, they do not adjust the learning rate.
Learning rate decay
Polyak averaging can be useful when training models, as the moving average of the parameters can produce better results than the parameters themselves. Since this can be costly in memory, an alternative the authors present is to decay the learning rate to 0 over the last 10% of steps. This can also be achieved using the `LambdaLR` class with the following lambda function:

```python
lr_lambda = lambda epoch: min(1., (total_epochs - epoch) / (0.1 * total_epochs))
```
To incorporate both warm-up and decay, we can combine the two functions:

```python
lr_lambda = lambda epoch: min(1., (epoch / warm_up_epochs) ** 2,
                              (total_epochs - epoch) / (0.1 * total_epochs))
```
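Putting the warm-up and decay together with the optimizer, a sketch of a full loop might look like the following (the model, loss, data, and epoch counts are placeholders chosen for illustration):

```python
import torch
from SM3 import SM3

# Stand-in model and synthetic data; replace with your own.
model = torch.nn.Linear(10, 1)
loss_fn = torch.nn.MSELoss()
data = [(torch.randn(32, 10), torch.randn(32, 1)) for _ in range(8)]

warm_up_epochs, total_epochs = 5, 100  # example values

opt = SM3(model.parameters())
lr_lambda = lambda epoch: min(1., (epoch / warm_up_epochs) ** 2,
                              (total_epochs - epoch) / (0.1 * total_epochs))
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

for epoch in range(total_epochs):
    for inputs, targets in data:
        opt.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        opt.step()
    scheduler.step()  # advance the warm-up/decay schedule once per epoch
```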