# MoMo: Momentum Models for Adaptive Learning Rates
PyTorch implementation of MoMo methods: adaptive learning rates for SGD with momentum (SGD-M) and for Adam.
You can install the package with

```shell
pip install momo-opt
```
Import the optimizers in Python with

```python
from momo import Momo
opt = Momo(model.parameters(), lr=1)
```

or

```python
from momo import MomoAdam
opt = MomoAdam(model.parameters(), lr=1e-2)
```
Note that Momo needs access to the value of the batch loss. In the `.step()` method, you need to pass either

- the loss tensor (when backward has already been done) to the argument `loss`,
- or a callable `closure` to the argument `closure` that computes gradients and returns the loss.
For example:

```python
def compute_loss(output, labels):
    loss = criterion(output, labels)
    return loss

# in each training step, use:
closure = lambda: compute_loss(output, labels)
```
For more details, see a full example script.
*Benchmark examples: ResNet110 for CIFAR100 and ResNet20 for CIFAR10.*
In general, if you expect SGD-M to work well on your task, then use Momo. If you expect Adam to work well on your problem, then use MomoAdam.
- The options `lr` and `weight_decay` are the same as in standard optimizers. As Momo and MomoAdam automatically adapt the learning rate, you should get good performance without heavy tuning of `lr` and without setting a schedule. Setting `lr` constant should work fine. For Momo, our experiments work well with `lr=1`; for MomoAdam, `lr=1e-2` (or slightly smaller) should work well.
One of the main goals of Momo optimizers is to reduce the tuning effort for the learning-rate schedule and get good performance for a wide range of learning rates.
- For Momo, the argument `beta` refers to the momentum parameter. The default is `beta=0.9`. For MomoAdam, `(beta1, beta2)` have the same role as in Adam.
- The argument `lb` refers to a lower bound of your loss function. In many cases, `lb=0` will be a good enough estimate. If your loss converges to a large positive number (and you roughly know that value), then set `lb` to this value (or slightly smaller).
If you cannot estimate a lower bound before training, use the option `use_fstar=True`. This will activate an online estimation of the lower bound.