
Project description


Heuristically optimize learning rates in neural networks through subsampling loss differentials.


Auto LR Tuner

The learning rate is often one of the most important hyperparameters when training neural networks. Adaptive gradient methods (e.g., Adam), decaying the learning rate based on validation performance, and cosine annealing are common techniques used in practice to improve convergence.

This library provides a simple algorithm for automatically tuning learning rates for TensorFlow Keras models. Importantly, these methods are largely heuristic, drawn from my own experience training neural networks; there are no formal guarantees.

Algorithm Details

The main idea behind this implementation is to estimate the optimal learning rate by probing the steepness of the loss surface. Intuitively, a very small learning rate produces almost no change in the loss, while an excessively large learning rate can overshoot local minima and even increase the loss.

We start with a user-specified grid $\Theta \subset [lr_{min}, lr_{max}]$ of candidate learning rates to search over. Next, we subsample $m$ observations from the data and evaluate $\Delta L := L_{post} - L_{pre}$, where $L_{pre}$ is the baseline loss and $L_{post}$ is the loss after a single backpropagation step at a given learning rate $\theta \in \Theta$. The subsampling is repeated $M$ times to approximate $E_{M}[\Delta L]$ and to construct confidence intervals from the subsampled distribution $\hat{P}_{\Delta L, M}$.
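The procedure above can be sketched with a toy model. The following is an illustrative NumPy stand-in (a least-squares model in place of a Keras network) and is not the library's implementation; all names here are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy regression data and a linear model with loss L(w) = mean((Xw - y)^2),
# standing in for a compiled Keras model.
n, d = 500, 5
X = rng.normal(size=(n, d))
w_true = rng.normal(size=d)
y = X @ w_true + 0.1 * rng.normal(size=n)

def loss(w, Xb, yb):
    r = Xb @ w - yb
    return float(np.mean(r ** 2))

def grad(w, Xb, yb):
    return 2.0 * Xb.T @ (Xb @ w - yb) / len(yb)

lr_grid = np.logspace(-3, 0, 10)   # grid Theta over [lr_min, lr_max]
m, M = 32, 50                      # subsample size, number of repeats
w0 = np.zeros(d)                   # fixed baseline weights

delta_L = np.zeros((len(lr_grid), M))
for i, lr in enumerate(lr_grid):
    for j in range(M):
        idx = rng.choice(n, size=m, replace=False)  # subsample m observations
        Xb, yb = X[idx], y[idx]
        L_pre = loss(w0, Xb, yb)                    # baseline loss
        w1 = w0 - lr * grad(w0, Xb, yb)             # single gradient step
        delta_L[i, j] = loss(w1, Xb, yb) - L_pre    # Delta L = L_post - L_pre

mean_dL = delta_L.mean(axis=1)                      # approximates E_M[Delta L]
ci = 1.96 * delta_L.std(axis=1) / np.sqrt(M)        # normal-approx half-width
```

With a quadratic loss, the smallest rates yield tiny but reliably negative $\Delta L$, while rates near one can push $\Delta L$ positive, which is the shape of curve the tuner looks for.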

In this interpretation of the problem, the "optimal" learning rate is the one that consistently decreases the loss across the dataset without excessively high variance. For example, in the plot below, a near-zero learning rate may be very slow to converge, while a learning rate of one actually increases the loss on average. Here the optimal rate is likely close to 0.6.
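One way to make "consistently decreases the loss without excessive variance" concrete is the selection rule sketched below. This is an illustrative heuristic, not the library's internals: among rates whose 95% upper confidence bound on $E[\Delta L]$ is still negative, pick the one with the lowest mean loss differential.

```python
import numpy as np

def pick_best_lr(lr_grid, delta_L):
    """Illustrative rule: delta_L has shape (len(lr_grid), M), one row of
    subsampled loss differentials per candidate learning rate."""
    M = delta_L.shape[1]
    mean_dL = delta_L.mean(axis=1)
    upper = mean_dL + 1.96 * delta_L.std(axis=1) / np.sqrt(M)
    reliable = upper < 0                        # loss decreases consistently
    if not reliable.any():
        return lr_grid[np.argmin(mean_dL)]      # fall back to best mean
    idx = np.argmin(np.where(reliable, mean_dL, np.inf))
    return lr_grid[idx]
```

In the plot's scenario, a rate near 0.6 would win: its mean differential is large and negative while its confidence band stays below zero, whereas a rate of one fails the reliability check.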

[Plot: mean loss differential with confidence bands across the learning-rate grid]

Getting Started

Install from PyPi:

$ pip install autolrtuner

Dependencies

  • python >= 3.8
  • tensorflow >= 2.5
  • numpy
  • pandas
  • scipy
  • tqdm
  • matplotlib (Optional)

Example

from tensorflow import keras

from autolrtuner import AutoLRTuner

# Any compiled TensorFlow Keras model works; a minimal model for illustration
model = keras.Sequential([
    keras.layers.Dense(32, activation='relu', input_shape=(10,)),
    keras.layers.Dense(1),
])
model.compile(optimizer='sgd', loss='mse')

# Run the tuning algorithm (X, Y are the training features and labels)
lrTuner = AutoLRTuner(model)
lrTuner.Tune(X, Y, num_subsamples=100, num_evals=100, batch_size=32)

# Get the optimal learning rate
bestLR = lrTuner.GetBestLR(method='mean')

# Plot the loss differentials (requires matplotlib)
lrTuner.Plot()


Download files

Download the file for your platform.

Source Distribution

autolrtuner-1.0.1.tar.gz (7.4 kB)

Uploaded Source

Built Distribution


autolrtuner-1.0.1-py3-none-any.whl (7.5 kB)

Uploaded Python 3

File details

Details for the file autolrtuner-1.0.1.tar.gz.

File metadata

  • Download URL: autolrtuner-1.0.1.tar.gz
  • Upload date:
  • Size: 7.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.5

File hashes

Hashes for autolrtuner-1.0.1.tar.gz:

  • SHA256: 0903aa4a1dbda6d84e9dbd28f83eb5884fe84b02c6c0ed4faf74cee20b2312e3
  • MD5: 5ca90f4f0253811f5a7705e7aaee767c
  • BLAKE2b-256: 8ca2197ed27736fd5c2c38a3bcfb6cb19148a255f8ecfceb3e602c579750cbd2


File details

Details for the file autolrtuner-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: autolrtuner-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 7.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.5

File hashes

Hashes for autolrtuner-1.0.1-py3-none-any.whl:

  • SHA256: 2a08df9d9b36cb7c47b15cb7acb3bbfcae73aab7014e8a115deec9a23ad3f504
  • MD5: e4371136d1f6ceeb4aafa0a05524960b
  • BLAKE2b-256: a8f0c1f81471becb40f4b0ca3d14ccc163739e745701ed78f526064bfddacf19

