
Project description

Heuristically optimize learning rates in neural networks through subsampling loss differentials.


Auto LR Tuner

The learning rate is often one of the most important hyperparameters when training neural networks. Adaptive gradient methods (e.g. Adam), reducing the learning rate when a validation metric plateaus, and cosine annealing are common techniques used in practice to improve convergence.
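
(As a point of reference, each of these baselines is a one-liner in Keras. The snippet below is purely illustrative and independent of this library.)

import tensorflow as tf

# Illustrative Keras equivalents of the techniques mentioned above
opt = tf.keras.optimizers.Adam()                     # adaptive gradient method
plateau = tf.keras.callbacks.ReduceLROnPlateau(      # decay LR on a validation plateau
    monitor='val_loss', factor=0.5, patience=3)
cosine = tf.keras.optimizers.schedules.CosineDecay(  # cosine annealing schedule
    initial_learning_rate=1e-3, decay_steps=10_000)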

This library provides a simple algorithm for automatically tuning learning rates for TensorFlow Keras models. Importantly, these methods are based largely on heuristics and my own experience training neural networks; there are no formal guarantees.

Algorithm Details

The main idea behind this implementation is to estimate the optimal learning rate by probing the steepness of the loss surface. Intuitively, a very small learning rate barely changes the loss, while an excessively large learning rate can overshoot local minima and even increase the loss.

We start with a user-specified grid $\Theta \subset [lr_{\min}, lr_{\max}]$ of candidate learning rates to search over. Next, we subsample $m$ observations from the data and, for a given learning rate $\theta \in \Theta$, evaluate $\Delta L := L_{post} - L_{pre}$, where $L_{pre}$ is the baseline loss and $L_{post}$ is the loss after a single backpropagation step at rate $\theta$. The subsampling process is repeated $M$ times to approximate $\mathbb{E}[\Delta L]$ and to construct confidence intervals from the subsampled distribution $\hat{P}_{\Delta L, M}$.
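
As a rough sketch of this procedure (not the library's actual implementation; the helper name, signature, and defaults below are assumptions for illustration), the estimator for a single candidate rate $\theta$ could look like:

import numpy as np
import tensorflow as tf

def estimate_delta_loss(model, X, Y, lr, m=32, M=100):
    # Hypothetical helper (not part of autolrtuner): approximate E[Delta L]
    # for one candidate learning rate via M subsamples of size m.
    init_weights = model.get_weights()         # snapshot to restore each trial
    loss_fn = tf.keras.losses.get(model.loss)  # resolve the compiled loss
    deltas = []
    for _ in range(M):
        idx = np.random.choice(len(X), size=m, replace=False)
        xb, yb = tf.constant(X[idx]), tf.constant(Y[idx])
        with tf.GradientTape() as tape:
            L_pre = tf.reduce_mean(loss_fn(yb, model(xb, training=True)))
        grads = tape.gradient(L_pre, model.trainable_variables)
        for v, g in zip(model.trainable_variables, grads):
            v.assign_sub(lr * g)               # one plain SGD step at rate lr
        L_post = tf.reduce_mean(loss_fn(yb, model(xb, training=False)))
        deltas.append(float(L_post - L_pre))
        model.set_weights(init_weights)        # reset before the next subsample
    deltas = np.asarray(deltas)
    se = deltas.std(ddof=1) / np.sqrt(M)
    mean = deltas.mean()
    return mean, (mean - 1.96 * se, mean + 1.96 * se)  # ~95% normal CI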

In this interpretation of the problem, the "optimal" learning rate is the one that consistently decreases the loss across the dataset without excessively high variance. For example, in the plot below, a near-zero learning rate may be very slow to converge, while a learning rate of one actually increases the loss on average. Here the optimal rate may be something close to 0.125.

[Plot: mean loss differential $\Delta L$ with confidence bands across the grid of candidate learning rates]
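
A minimal sketch of one such selection rule, assuming the mean and standard deviation of $\Delta L$ have already been estimated at each grid point (the function and its risk_aversion parameter are hypothetical, not part of this library's API):

import numpy as np

def pick_best_lr(grid, mean_dl, std_dl, risk_aversion=1.0):
    # Hypothetical rule matching the description above: prefer the most
    # negative mean loss change, penalized by its subsample variability.
    score = np.asarray(mean_dl) + risk_aversion * np.asarray(std_dl)
    return np.asarray(grid)[np.argmin(score)]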

Getting Started

Install from PyPi:

$ pip install autolrtuner

Dependencies

  • python >= 3.8
  • tensorflow >= 2.5
  • numpy
  • pandas
  • scipy
  • tqdm
  • matplotlib (Optional)

Example

import numpy as np
import tensorflow as tf
from autolrtuner import AutoLRTuner

# Any compiled TensorFlow Keras model works; a minimal toy example:
model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='sgd', loss='mse')

# Toy training data (replace with your own features X and targets Y)
X = np.random.normal(size=(1000, 10))
Y = np.random.normal(size=(1000, 1))

# Run the tuning algorithm
lrTuner = AutoLRTuner(model)
lrTuner.Tune(X, Y, num_subsamples=100, num_evals=100, batch_size=32)

# Get the optimal learning rate
bestLR = lrTuner.GetBestLR(method='mean')

# Plot the loss differentials
lrTuner.Plot()
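
Once tuned, the rate can be applied back to the compiled optimizer before full training. A possible follow-up, assuming the toy model and data above:

# Apply the tuned learning rate and train as usual
tf.keras.backend.set_value(model.optimizer.learning_rate, bestLR)
model.fit(X, Y, epochs=10, batch_size=32)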

Download files

Download the file for your platform.

Source Distribution

autolrtuner-1.0.2.tar.gz (7.5 kB)

Built Distribution


autolrtuner-1.0.2-py3-none-any.whl (7.5 kB)

File details

Details for the file autolrtuner-1.0.2.tar.gz.

File metadata

  • Download URL: autolrtuner-1.0.2.tar.gz
  • Upload date:
  • Size: 7.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.5

File hashes

Hashes for autolrtuner-1.0.2.tar.gz

  • SHA256: dad8a89e612565cb9d1629899a0636d42e4b68e42739dcc2ee5207b5886d5928
  • MD5: b17dfafbc540898cf7ebc291a8683d8f
  • BLAKE2b-256: 633f3b8c22846f5b20423ad4dc8cf1ab6e93b68a09c5be5720bd0d9dc1a447af


File details

Details for the file autolrtuner-1.0.2-py3-none-any.whl.

File metadata

  • Download URL: autolrtuner-1.0.2-py3-none-any.whl
  • Upload date:
  • Size: 7.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.5

File hashes

Hashes for autolrtuner-1.0.2-py3-none-any.whl

  • SHA256: ab50b3be740128a7cd226f5f4845d4ad77715623b24166edb9c550ab5c5dc1bb
  • MD5: e8de7ff644c89300bcd8a4180dd0c52f
  • BLAKE2b-256: 64af901b3d0906ff4dcceaaffcad023833b5d0fa495da0341f5f73c8e75f5ee4

