# MoMo: Momentum Models for Adaptive Learning Rates
PyTorch implementation of the MoMo methods: adaptive learning rates for SGD with momentum (SGD-M) and for Adam.
## Installation
You can install the package with

```
pip install momo-opt
```
## Usage
Import the optimizers in Python with

```python
from momo import Momo
opt = Momo(model.parameters(), lr=1)
```

or

```python
from momo import MomoAdam
opt = MomoAdam(model.parameters(), lr=1e-2)
```
Note that Momo needs access to the value of the batch loss. In the `.step()` method, you need to pass either

- the loss tensor (when backward has already been done) to the argument `loss`, or
- a callable `closure` to the argument `closure` that computes gradients and returns the loss.
For example:

```python
def compute_loss(output, labels):
    loss = criterion(output, labels)
    loss.backward()
    return loss

# in each training step, use:
closure = lambda: compute_loss(output, labels)
opt.step(closure=closure)
```
For more details, see a full example script.
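Since the full script is not reproduced here, a minimal, self-contained training loop might look like the following sketch. The model, data, and `criterion` are illustrative placeholders, not part of the package:

```python
import torch
from momo import Momo

# Illustrative setup; replace with your own model, data, and loss.
model = torch.nn.Linear(10, 2)
criterion = torch.nn.CrossEntropyLoss()
opt = Momo(model.parameters(), lr=1)

inputs = torch.randn(32, 10)
labels = torch.randint(0, 2, (32,))

for step in range(100):
    opt.zero_grad()
    output = model(inputs)

    # Variant 1: run backward yourself, then pass the loss tensor.
    loss = criterion(output, labels)
    loss.backward()
    opt.step(loss=loss)

    # Variant 2 (alternative): pass a closure that computes gradients
    # and returns the loss, as in the example above:
    #   opt.step(closure=lambda: compute_loss(output, labels))
```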
## Examples

- ResNet110 for CIFAR100
- ResNet20 for CIFAR10
## Recommendations
In general, if you expect SGD-M to work well on your task, then use Momo. If you expect Adam to work well on your problem, then use MomoAdam.
- The options `lr` and `weight_decay` are the same as in standard optimizers. As Momo and MomoAdam automatically adapt the learning rate, you should get good performance without heavy tuning of `lr` or setting a schedule; a constant `lr` should work fine. For Momo, our experiments work well with `lr=1`; for MomoAdam, `lr=1e-2` (or slightly smaller) should work well. One of the main goals of the Momo optimizers is to reduce the tuning effort for the learning-rate schedule and get good performance for a wide range of learning rates.
- For Momo, the argument `beta` is the momentum parameter; the default is `beta=0.9`. For MomoAdam, `(beta1, beta2)` have the same role as in Adam.
- The option `lb` refers to a lower bound of your loss function. In many cases, `lb=0` will be a good enough estimate. If your loss converges to a large positive number (and you roughly know the value), then set `lb` to this value (or slightly smaller).
- If you cannot estimate a lower bound before training, use the option `use_fstar=True`. This activates an online estimation of the lower bound, as in the sketch below.
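Putting these recommendations together, a minimal configuration sketch (only the options documented above are used; `model` is assumed to be an existing PyTorch module):

```python
from momo import Momo, MomoAdam

# Loss known to be bounded below by zero (e.g. cross-entropy):
opt = Momo(model.parameters(), lr=1, beta=0.9, lb=0.0)

# Lower bound unknown before training: estimate it online instead.
opt = MomoAdam(model.parameters(), lr=1e-2, use_fstar=True)
```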