Heuristically optimize learning rates in neural networks through subsampling loss differentials.


Auto LR Tuner

The learning rate is often one of the most important hyperparameters when training neural networks. Adaptive gradient-based methods (e.g. Adam), decreasing the learning rate based on validation-set performance, and cosine annealing are common techniques used in practice to improve convergence.

This library provides a simple algorithm for automatically tuning learning rates for TensorFlow Keras models. Importantly, these methods are largely based on heuristics and my own experience training neural networks; there are no formal results.

Algorithm Details

The main idea behind this implementation is to estimate the optimal learning rate by trying to determine the steepness of the loss surface. Intuitively, a very small learning rate leads to almost no change in losses, while an excessively large learning rate can overshoot local minima and even increase the loss.
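
To make this intuition concrete, here is a minimal sketch on a toy one-dimensional loss $L(w) = w^2$, which is an illustrative assumption and not part of the library. It applies a single plain gradient descent step and reports the change in loss for a very small, a moderate, and an excessively large learning rate.

```python
def loss(w):
    """Toy quadratic loss L(w) = w^2 (illustrative assumption)."""
    return w ** 2


def grad(w):
    """Gradient of the toy loss."""
    return 2.0 * w


def loss_differential(w, lr):
    """Return L_post - L_pre after one gradient descent step from w."""
    w_new = w - lr * grad(w)
    return loss(w_new) - loss(w)


w0 = 1.0
for lr in (0.001, 0.5, 1.5):
    print(f"lr={lr}: delta_L={loss_differential(w0, lr):+.4f}")
```

On this toy loss, the smallest rate barely changes the loss, the moderate rate reaches the minimum in one step, and the largest rate overshoots and increases the loss.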

We start with a user-specified grid $\Theta = [lr_{min}, lr_{max}]$ of potential learning rates to search over. Next, we subsample $m$ observations from the data, and evaluate $\Delta L := L_{post} - L_{pre}$ where $L_{pre}$ is the baseline loss value, and $L_{post}$ is the loss after a single backpropagation step for a given learning rate $\theta \in \Theta$. The subsampling process is repeated $M$ times to approximate $\mathbb{E}_{M}[\Delta L]$ and construct confidence intervals based on the subsampled distribution $\hat{P}_{\Delta L, M}$.
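
The procedure above can be sketched in plain Python on a toy one-parameter linear model. This is not the library's implementation: the model, the grid values, the helper names (`mse`, `mse_grad`, `subsampled_loss_differentials`), the choice to evaluate $L_{post}$ on the same subsample as $L_{pre}$, and the percentile interval are all assumptions made for illustration.

```python
import random
import statistics


def mse(w, batch):
    """Mean squared error of the toy model y_hat = w * x on a batch."""
    return sum((w * x - y) ** 2 for x, y in batch) / len(batch)


def mse_grad(w, batch):
    """Gradient of the batch MSE with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def subsampled_loss_differentials(data, w, lr, m, M, rng):
    """Return M draws of Delta L = L_post - L_pre for one learning rate."""
    deltas = []
    for _ in range(M):
        batch = rng.sample(data, m)             # subsample m observations
        l_pre = mse(w, batch)                   # baseline loss
        w_post = w - lr * mse_grad(w, batch)    # one step from the same start
        deltas.append(mse(w_post, batch) - l_pre)
    return deltas


rng = random.Random(0)
xs = [rng.uniform(-1.0, 1.0) for _ in range(500)]
data = [(x, 3.0 * x + rng.gauss(0.0, 0.1)) for x in xs]

grid = [0.01, 0.1, 1.0, 2.0, 4.0]  # hypothetical grid of learning rates
results = {
    lr: subsampled_loss_differentials(data, 0.0, lr, m=32, M=200, rng=rng)
    for lr in grid
}

for lr, deltas in results.items():
    ordered = sorted(deltas)
    lower = ordered[int(0.025 * len(ordered))]       # empirical 2.5% point
    upper = ordered[int(0.975 * len(ordered)) - 1]   # empirical 97.5% point
    print(f"lr={lr}: mean={statistics.mean(deltas):+.3f} "
          f"interval=[{lower:+.3f}, {upper:+.3f}]")
```

Each step restarts from the same parameter value, so the draws for different learning rates are comparable. With a real Keras model this would presumably mean restoring the weights after every trial step.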

In this interpretation of the problem, the "optimal" learning rate is the one that consistently decreases the loss across the dataset without excessively high variance. For example, in the plot below, a near-zero learning rate may be very slow to converge, while a learning rate of 1.0 actually increases the loss on average. Here, the optimal rate may be something close to 0.6.
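
One way to turn this criterion into a rule is sketched below. It is an assumption for illustration and not necessarily what `GetBestLR` does: among the learning rates whose normal-approximation upper confidence bound on the mean $\Delta L$ is below zero, pick the one with the lowest mean. The function name `pick_learning_rate` and the sample values are hypothetical.

```python
import statistics


def pick_learning_rate(deltas_by_lr, z=1.96):
    """Pick the rate with the lowest mean Delta L whose upper bound is < 0.

    Returns None when no rate reliably decreases the loss.
    """
    best_lr, best_mean = None, 0.0
    for lr, deltas in deltas_by_lr.items():
        mean = statistics.mean(deltas)
        # Half-width of a normal-approximation interval for the mean
        half_width = z * statistics.stdev(deltas) / len(deltas) ** 0.5
        if mean + half_width < 0.0 and mean < best_mean:
            best_lr, best_mean = lr, mean
    return best_lr


# Hypothetical loss differentials for three candidate learning rates
deltas_by_lr = {
    0.01: [-0.02, -0.01, -0.03, -0.02],  # reliable but tiny decrease
    0.6: [-0.9, -1.1, -1.0, -0.8],       # large, consistent decrease
    1.0: [0.4, -0.6, 1.2, 0.3],          # noisy, increases loss on average
}
print(pick_learning_rate(deltas_by_lr))
```

A rate with a low mean but a wide interval is rejected by the upper-bound check, which is how this sketch penalizes high variance.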

Plot

Getting Started

Install from PyPI:

$ pip install autolrtuner

Dependencies

  • python >= 3.7
  • tensorflow >= 2.5
  • numpy
  • pandas
  • scipy
  • tqdm
  • matplotlib (Optional)

Example

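import numpy as np
import tensorflow as tf
from autolrtuner import AutoLRTuner

# Placeholder data for illustration; substitute your own
X = np.random.rand(1000, 10).astype("float32")
Y = np.random.rand(1000, 1).astype("float32")

# Compiled TensorFlow Keras model (placeholder; substitute your own)
model = tf.keras.Sequential([tf.keras.Input(shape=(10,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")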

# Run the tuning algorithm
lrTuner = AutoLRTuner(model)
lrTuner.Tune(X, Y, num_subsamples=100, num_evals=100, batch_size=32)

# Get the optimal learning rate
bestLR = lrTuner.GetBestLR(method='mean')

# Plot the loss differentials
lrTuner.Plot()
