PyTorch implementation of the learning rate range test

Project description

PyTorch learning rate finder

A PyTorch implementation of the learning rate range test detailed in "Cyclical Learning Rates for Training Neural Networks" by Leslie N. Smith, and the tweaked version used by fastai.

The learning rate range test provides valuable information about the optimal learning rate. During a pre-training run, the learning rate is increased linearly or exponentially between two boundaries. The low initial learning rate allows the network to start converging; as the learning rate increases, it eventually becomes too large and the network diverges.
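For concreteness, here is a sketch of the two schedules over num_iter iterations; the names start_lr, end_lr and num_iter are illustrative, and the library's internal implementation may differ:

# Sketch of the two range-test schedules; start_lr, end_lr and num_iter
# are illustrative names rather than the library's own internals.
def linear_lr(i, start_lr, end_lr, num_iter):
    # The learning rate grows by a constant increment each iteration.
    return start_lr + (end_lr - start_lr) * i / num_iter

def exponential_lr(i, start_lr, end_lr, num_iter):
    # The learning rate grows by a constant factor each iteration.
    return start_lr * (end_lr / start_lr) ** (i / num_iter)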

Typically, a good static learning rate can be found halfway down the descending part of the loss curve. In the plot below that would be lr = 0.002.

For cyclical learning rates (also detailed in Leslie Smith's paper), where the learning rate is cycled between two boundaries (base_lr, max_lr), the author advises picking base_lr at the point where the loss starts descending, and max_lr at the point where the loss stops descending or becomes ragged. In the plot below, base_lr = 0.0002 and max_lr = 0.2.

[Plot: learning rate range test, loss versus learning rate]
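Those two boundaries can then be fed into a cyclical schedule. As a sketch, using PyTorch's built-in torch.optim.lr_scheduler.CyclicLR with the example values read off the plot above (model and trainloader are assumed to exist, as in the usage examples below):

import torch.optim as optim
from torch.optim.lr_scheduler import CyclicLR

# Sketch: cycle the learning rate between the boundaries found by the test.
# The base_lr/max_lr values are the ones read off the plot above.
optimizer = optim.SGD(model.parameters(), lr=0.0002, momentum=0.9)
scheduler = CyclicLR(optimizer, base_lr=0.0002, max_lr=0.2)
for inputs, targets in trainloader:
    ...  # forward pass, loss, backward pass, optimizer.step()
    scheduler.step()  # advance the cyclical schedule once per batch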

Installation

Requires Python 2.7 or later:

pip install torch-lr-finder

Implementation details and usage

Tweaked version from fastai

Increases the learning rate exponentially and computes the training loss for each learning rate. lr_finder.plot() plots the training loss against the learning rate on a logarithmic scale.

import torch.nn as nn
import torch.optim as optim

from torch_lr_finder import LRFinder

model = ...
criterion = nn.CrossEntropyLoss()
# Start from a very small learning rate; the test will grow it towards end_lr.
optimizer = optim.Adam(model.parameters(), lr=1e-7, weight_decay=1e-2)
lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
lr_finder.range_test(trainloader, end_lr=100, num_iter=100)
lr_finder.plot()  # training loss versus learning rate, logarithmic x-axis
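As the notes below explain, the range test changes the model weights and the optimizer parameters, so it is worth restoring them once the plot has been inspected:

lr_finder.reset()  # restore the model and optimizer to their initial state

# The recorded test data is a dict with "lr" and "loss" keys.
lrs = lr_finder.history["lr"]
losses = lr_finder.history["loss"]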

Leslie Smith's approach

Increases the learning rate linearly and computes the evaluation loss for each learning rate. lr_finder.plot() plots the evaluation loss versus the learning rate. This approach typically produces more precise curves, because the evaluation loss is more susceptible to divergence, but the test takes significantly longer, especially when the evaluation dataset is large.

import torch.nn as nn
import torch.optim as optim

from torch_lr_finder import LRFinder

model = ...
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.1, weight_decay=1e-2)
lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
# val_loader is a DataLoader over the evaluation set used to compute the loss.
lr_finder.range_test(trainloader, val_loader=val_loader, end_lr=1, num_iter=100, step_mode="linear")
lr_finder.plot(log_lr=False)  # evaluation loss versus learning rate, linear x-axis
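Both snippets assume that trainloader (and, here, val_loader) already exist. As a self-contained sketch, they could be built with torchvision's CIFAR10 dataset, matching the examples folder; the batch sizes and transform are illustrative:

import torch
import torchvision
import torchvision.transforms as transforms

# Illustrative loaders for the snippets above; any torch.utils.data.DataLoader works.
transform = transforms.ToTensor()
trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=256, shuffle=False)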

Notes

  • Examples for CIFAR10 and MNIST can be found in the examples folder.
  • LRFinder.range_test() will change the model weights and the optimizer parameters. Both can be restored to their initial state with LRFinder.reset().
  • The learning rate and loss history can be accessed through lr_finder.history. This will return a dictionary with lr and loss keys (see the sketch after this list).
  • When using step_mode="linear" the learning rate range should be within the same order of magnitude.
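
As an illustration of working with that history (continuing from the examples above), here is a sketch of a common heuristic that is not part of this package's API: suggest the learning rate at which the loss curve descends most steeply.

import numpy as np

# Heuristic sketch, not a torch-lr-finder API: pick the learning rate at the
# steepest descent of the loss curve recorded by the range test.
lrs = np.array(lr_finder.history["lr"])
losses = np.array(lr_finder.history["loss"])
steepest = np.gradient(losses).argmin()  # index of the most negative slope
print("Suggested learning rate: {:.2e}".format(lrs[steepest]))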

Download files

Download the file for your platform.

Source Distribution

torch-lr-finder-0.0.1.tar.gz (6.0 kB)


Built Distribution

torch_lr_finder-0.0.1-py3-none-any.whl (7.4 kB)


File details

Details for the file torch-lr-finder-0.0.1.tar.gz.

File metadata

  • Download URL: torch-lr-finder-0.0.1.tar.gz
  • Size: 6.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.14.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/40.6.2 requests-toolbelt/0.9.1 tqdm/4.35.0 CPython/3.6.8

File hashes

Hashes for torch-lr-finder-0.0.1.tar.gz:

  • SHA256: 3d1f91f0232f069325b456348b97d7936921f77393cef10e8253b3cb7cc2932d
  • MD5: f5f590e69484e6fb07314582669ea875
  • BLAKE2b-256: 5dcd91f910b0b05cb72ad66db3b8a82b86f8e1ab4241398eef8dc4dfd033a7eb


File details

Details for the file torch_lr_finder-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: torch_lr_finder-0.0.1-py3-none-any.whl
  • Size: 7.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.14.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/40.6.2 requests-toolbelt/0.9.1 tqdm/4.35.0 CPython/3.6.8

File hashes

Hashes for torch_lr_finder-0.0.1-py3-none-any.whl:

  • SHA256: 7f690a2e093b10c0e1935e013572c4be7225820c5be40af56139a6c28626db9e
  • MD5: 134a574f9dab12ebc005ddff441694ec
  • BLAKE2b-256: 627c397078ef7ca83880a35037285fd9d1f0e70ab5414db3a6f81e868c7230e4

