A PyTorch Extension for Learning Rate Warmup
This library contains PyTorch implementations of the warmup schedules described in the paper "On the adequacy of untuned warmup for adaptive optimization".
Installation
Make sure you have Python 3.6+ and PyTorch 1.1+. Then, run the following command:
```shell
python setup.py install
```

or

```shell
pip install -U pytorch_warmup
```
Usage
Sample Code
The learning rate produced by the schedule is dampened by multiplying it by the warmup factor w(t): at step t, the optimizer uses w(t) times the scheduled learning rate.
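To make the mechanism concrete, here is a minimal pure-Python sketch of a dampening context manager. It is not the library's implementation: the class name `ToyWarmup`, the list of dicts standing in for a PyTorch optimizer's `param_groups`, and the use of linear warmup counted from t = 1 are all assumptions for illustration. On entry, the sketch restores the undampened learning rate so the scheduler updates its own value; on exit, it saves the scheduled value and multiplies it by w(t).

```python
from contextlib import contextmanager


class ToyWarmup:
    """Hypothetical linear-warmup scheduler, for illustration only."""

    def __init__(self, param_groups, warmup_period):
        self.param_groups = param_groups  # list of dicts with an "lr" key
        self.warmup_period = warmup_period
        self.t = 0
        # Undampened (scheduled) learning rates, one per parameter group.
        self.scheduled_lrs = [group["lr"] for group in param_groups]
        self._dampen()  # dampen the first optimizer step as well

    def factor(self, t):
        # Linear warmup factor w(t) = min(1, t / warmup_period).
        return min(1.0, t / self.warmup_period)

    def _dampen(self):
        self.t += 1
        w = self.factor(self.t)
        for group in self.param_groups:
            group["lr"] *= w

    @contextmanager
    def dampening(self):
        # Restore the undampened rates so the scheduler sees its own values.
        for group, lr in zip(self.param_groups, self.scheduled_lrs):
            group["lr"] = lr
        yield
        # Remember what the scheduler produced, then apply w(t) on top.
        self.scheduled_lrs = [group["lr"] for group in self.param_groups]
        self._dampen()


# Usage: a toy "scheduler" that halves the learning rate at every step.
param_groups = [{"lr": 0.1}]
toy = ToyWarmup(param_groups, warmup_period=4)
history = [param_groups[0]["lr"]]
for _ in range(3):
    with toy.dampening():
        param_groups[0]["lr"] *= 0.5  # stands in for lr_scheduler.step()
    history.append(param_groups[0]["lr"])
```

Once t reaches the warmup period the factor is 1, so in this sketch the optimizer then sees exactly the scheduled learning rate.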
Approach 1
When the learning rate schedule uses the global iteration number, the untuned linear warmup can be used as follows:
```python
import torch
import pytorch_warmup as warmup

# `params`, `dataloader` and `num_epochs` are assumed to be defined elsewhere.
optimizer = torch.optim.AdamW(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)
num_steps = len(dataloader) * num_epochs
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
for epoch in range(1, num_epochs + 1):
    for batch in dataloader:
        optimizer.zero_grad()
        loss = ...  # compute the loss for this batch
        loss.backward()
        optimizer.step()
        # Step the schedule once per iteration; the warmup factor is applied on top.
        with warmup_scheduler.dampening():
            lr_scheduler.step()
```
If you want to use learning rate scheduler "chaining", which is supported in PyTorch 1.4.0 and later, simply place all the scheduler `step()` calls in the suite of the `with` statement:
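```python
# `optimizer`, `dataloader` and `num_epochs` are set up as in the previous example.
lr_scheduler1 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
for epoch in range(1, num_epochs + 1):
    for batch in dataloader:
        ...  # zero the gradients, compute the loss and backpropagate
        optimizer.step()
        # Both chained schedulers are stepped inside a single dampening context.
        with warmup_scheduler.dampening():
            lr_scheduler1.step()
            lr_scheduler2.step()
```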
Approach 2
When the learning rate schedule uses the epoch number, the warmup schedule can be used as follows:
```python
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[num_epochs // 3], gamma=0.1)
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
for epoch in range(1, num_epochs + 1):
    for i, batch in enumerate(dataloader):
        optimizer.zero_grad()
        loss = ...  # compute the loss for this batch
        loss.backward()
        optimizer.step()
        # Advance the warmup on every iteration but the last one of the epoch.
        if i < len(dataloader) - 1:
            with warmup_scheduler.dampening():
                pass
    # The last warmup step of the epoch also steps the epoch-based schedule.
    with warmup_scheduler.dampening():
        lr_scheduler.step()
```
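The `pass` suite in Approach 2 exists so that the warmup factor advances once per iteration even though the epoch-based schedule advances only once per epoch. The pure-Python simulation below sketches the learning rate each optimizer step would see. The sizes are hypothetical, and it assumes linear warmup with w(t) = min(1, t / warmup_period), where t counts optimizer steps starting from 1, combined with a single-milestone MultiStepLR.

```python
# Hypothetical sizes, chosen only for illustration.
num_epochs = 3
iters_per_epoch = 4
warmup_period = 6            # stands in for 2 / (1 - beta2)
base_lr = 0.001
milestone = num_epochs // 3  # as in MultiStepLR(milestones=[num_epochs // 3])
gamma = 0.1

t = 1                # warmup step counter (assumed to start at 1)
scheduler_steps = 0  # number of lr_scheduler.step() calls so far
effective_lrs = []
for _ in range(num_epochs):
    for _ in range(iters_per_epoch):
        # MultiStepLR has multiplied by gamma once the milestone is reached.
        scheduled_lr = base_lr * (gamma if scheduler_steps >= milestone else 1.0)
        # Learning rate seen by optimizer.step() on this iteration.
        effective_lrs.append(scheduled_lr * min(1.0, t / warmup_period))
        t += 1  # one dampening() call per iteration, including the last
    scheduler_steps += 1  # lr_scheduler.step() once per epoch
```

With these numbers, warmup is still in progress at the start of the second epoch (t = 5), after the milestone has already reduced the scheduled rate by gamma.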
Warmup Schedules
Manual Warmup
For LinearWarmup and ExponentialWarmup, the warmup factor w(t) depends on the warmup period, which must be specified manually.
Linear
w(t) = min(1, t / warmup_period)
```python
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=2000)
```
Exponential
w(t) = 1 - exp(-t / warmup_period)
```python
warmup_scheduler = warmup.ExponentialWarmup(optimizer, warmup_period=1000)
```
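Both manual warmup factors can be written directly from the formulas above. This is a sketch of the formulas only, not the library's code; the function names are hypothetical, and t is taken to be the 1-based step count.

```python
import math


def linear_warmup_factor(t, warmup_period):
    """w(t) = min(1, t / warmup_period): reaches 1 at t = warmup_period."""
    return min(1.0, t / warmup_period)


def exponential_warmup_factor(t, warmup_period):
    """w(t) = 1 - exp(-t / warmup_period): approaches 1 asymptotically."""
    return 1.0 - math.exp(-t / warmup_period)


for t in (1, 1000, 2000, 4000):
    print(f"t={t}: linear={linear_warmup_factor(t, 2000):.4f}, "
          f"exponential={exponential_warmup_factor(t, 1000):.4f}")
```

At t = warmup_period the exponential factor equals 1 - e^(-1), roughly 0.632, so it needs several warmup periods to get close to 1, whereas the linear factor is exactly 1 from t = warmup_period on.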
Untuned Warmup
For UntunedLinearWarmup and UntunedExponentialWarmup, the warmup period is given by a function of Adam's beta2 parameter.
Linear
warmup_period = 2 / (1 - beta2)
```python
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
```
Exponential
warmup_period = 1 / (1 - beta2)
```python
warmup_scheduler = warmup.UntunedExponentialWarmup(optimizer)
```
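As a worked example of the two formulas: with beta2 = 0.999, the value used in the AdamW example above, the untuned linear warmup period is 2 / 0.001 = 2000 steps and the untuned exponential warmup period is 1 / 0.001 = 1000 steps, the same values passed manually to LinearWarmup and ExponentialWarmup earlier. A small sketch of the arithmetic, with hypothetical helper names:

```python
def untuned_linear_warmup_period(beta2):
    """warmup_period = 2 / (1 - beta2)."""
    return 2.0 / (1.0 - beta2)


def untuned_exponential_warmup_period(beta2):
    """warmup_period = 1 / (1 - beta2)."""
    return 1.0 / (1.0 - beta2)


for beta2 in (0.99, 0.999):
    # Floating-point rounding can leave these slightly off the exact values.
    print(beta2,
          round(untuned_linear_warmup_period(beta2)),
          round(untuned_exponential_warmup_period(beta2)))
```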
RAdam Warmup
For RAdamWarmup, the warmup factor depends on Adam's beta2 parameter. Please see the original paper for details.
```python
warmup_scheduler = warmup.RAdamWarmup(optimizer)
```
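For orientation, the RAdam paper (Liu et al., "On the Variance of the Adaptive Learning Rate and Beyond") defines a rectification term r_t that depends only on beta2 and the step t, and the untuned-warmup paper relates it to a warmup schedule. The sketch below computes that term from the RAdam formulas. It is an illustration under that assumption, not the code of RAdamWarmup: r_t is only defined once rho_t > 4, and the sketch simply returns None for earlier steps; how RAdamWarmup itself handles those steps is not covered here.

```python
import math


def radam_rectification(t, beta2):
    """RAdam's rectification term r_t for 1-based step t, or None if undefined."""
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** t
    rho_t = rho_inf - 2.0 * t * beta2_t / (1.0 - beta2_t)
    if rho_t <= 4.0:
        return None  # RAdam falls back to a non-adaptive update here
    numerator = (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
    denominator = (rho_inf - 4.0) * (rho_inf - 2.0) * rho_t
    return math.sqrt(numerator / denominator)


factors = [radam_rectification(t, 0.999) for t in range(1, 3001)]
first_defined = next(t for t, r in enumerate(factors, start=1) if r is not None)
print("first step with r_t defined:", first_defined)
```

Because rho_t grows toward rho_inf as t increases, r_t rises monotonically toward 1, which is what lets it be read as a warmup factor.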
Apex's Adam
The Apex library provides an Adam optimizer tuned for CUDA devices, FusedAdam. The FusedAdam optimizer can be used with the warmup schedulers. For example:
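```python
import apex.optimizers
import torch
import pytorch_warmup as warmup

# `params` and `num_steps` are defined as in Approach 1.
optimizer = apex.optimizers.FusedAdam(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
```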
License
MIT License
Copyright (c) 2019 Takenori Yamamoto