MoMo: Momentum Models for Adaptive Learning Rates
PyTorch implementation of the MoMo methods: adaptive learning rates for SGD with momentum (SGD-M) and for Adam.
Installation
You can install the package with

```
pip install momo-opt
```
Usage
Import the optimizers in Python with

```python
from momo import Momo
opt = Momo(model.parameters(), lr=1)
```

or

```python
from momo import MomoAdam
opt = MomoAdam(model.parameters(), lr=1e-2)
```
Note that Momo needs access to the value of the batch loss. In the `.step()` method, you need to pass either

- the loss tensor (when `backward()` has already been called) to the argument `loss`, or
- a callable `closure` that computes the gradients and returns the loss, to the argument `closure`.
For example:
```python
def compute_loss(output, labels):
    loss = criterion(output, labels)
    loss.backward()
    return loss

# in each training step, use:
closure = lambda: compute_loss(output, labels)
opt.step(closure=closure)
```
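Alternatively, if you call `backward()` yourself before `.step()`, you can pass the loss tensor directly. A minimal sketch, assuming `model`, `criterion`, `inputs`, and `labels` come from a standard PyTorch training step:

```python
opt.zero_grad()
output = model(inputs)
loss = criterion(output, labels)
loss.backward()        # gradients are computed before step()
opt.step(loss=loss)    # pass the loss value so Momo can adapt the learning rate
```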
For more details, see a full example script.
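As a rough end-to-end sketch of the closure pattern (assuming `model` and `train_loader` are defined, and reusing the `compute_loss` helper from above):

```python
import torch
from momo import Momo

criterion = torch.nn.CrossEntropyLoss()  # assumed loss; use whatever fits your task
opt = Momo(model.parameters(), lr=1)

for inputs, labels in train_loader:
    opt.zero_grad()
    output = model(inputs)
    # the closure computes gradients and returns the loss, as Momo requires
    closure = lambda: compute_loss(output, labels)
    opt.step(closure=closure)
```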
Examples
- ResNet110 for CIFAR100
- ResNet20 for CIFAR10
Recommendations
In general, if you expect SGD-M to work well on your task, then use Momo. If you expect Adam to work well on your problem, then use MomoAdam.
- The options `lr` and `weight_decay` work the same as in standard optimizers. Because Momo and MomoAdam adapt the learning rate automatically, you should get good performance without heavy tuning of `lr` and without setting a schedule; keeping `lr` constant should work fine. For Momo, our experiments work well with `lr=1`; for MomoAdam, `lr=1e-2` (or slightly smaller) should work well. One of the main goals of the Momo optimizers is to reduce the tuning effort for the learning-rate schedule and to achieve good performance over a wide range of learning rates. (See the configuration sketch after this list.)
- For Momo, the argument `beta` is the momentum parameter; the default is `beta=0.9`. For MomoAdam, `(beta1, beta2)` play the same role as in Adam.
- The option `lb` is a lower bound on your loss function. In many cases, `lb=0` will be a good enough estimate. If your loss converges to a large positive number (and you roughly know its value), set `lb` to that value (or slightly smaller).
- If you cannot estimate a lower bound before training, use the option `use_fstar=True`. This activates an online estimation of the lower bound.
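Putting these options together, here is a minimal configuration sketch; the `weight_decay` value is purely illustrative:

```python
from momo import Momo, MomoAdam

# Momo: constant learning rate, default momentum, known lower bound of zero
opt = Momo(model.parameters(), lr=1, beta=0.9, lb=0)

# MomoAdam: online lower-bound estimation instead of a fixed lb
opt = MomoAdam(model.parameters(), lr=1e-2, weight_decay=1e-4, use_fstar=True)
```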
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
File details
Details for the file `momo-opt-0.1.0.tar.gz`.
File metadata
- Download URL: momo-opt-0.1.0.tar.gz
- Upload date:
- Size: 7.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.16
File hashes
Algorithm | Hash digest
---|---
SHA256 | 4c4e9336652d68d0cad4dfddbc8f7a38acaa9e4e6fd8e83262294d547f737352
MD5 | 815578882ce61c45a029b03a5dcdecee
BLAKE2b-256 | ee0923651f542e8ac27e2ae63aa7b38365ff1d6b289b39254ada4c58f39e0e57
File details
Details for the file `momo_opt-0.1.0-py3-none-any.whl`.
File metadata
- Download URL: momo_opt-0.1.0-py3-none-any.whl
- Upload date:
- Size: 8.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.16
File hashes
Algorithm | Hash digest
---|---
SHA256 | 90648b8189bfc34cf183d8f2f286baa78c2ca1f0541ef332d3ff13cde77728c1
MD5 | c451de4700cc5a3029d310d80c595d79
BLAKE2b-256 | f3f604626a49f15cb3608f02ab84bcebdd7ca647f0b92fef7c2c5fe6c57eb4cd