
PyTorch implementation of the Generalized Newton's method (GeN), a learning-rate-free, Hessian-informed optimizer.

Project description

GeN: Generalized Newton's Method for Learning-Rate-Free Optimization 🚀


Paper: Gradient Descent with Generalized Newton’s Method (ICLR 2025)


📦 Repository Overview

This repository contains the code and examples for the Generalized Newton's method (GeN), a learning-rate-free optimizer. It supports a wide range of models and tasks, including:

  • 🖼️ Image classification (CIFAR10/CIFAR100/ImageNet... datasets with ViT/ResNet models)
  • 📝 Natural language generation (E2E/DART... datasets with GPT2 models)
  • 📊 Natural language understanding (SST2/QNLI/MNLI... datasets with BERT/RoBERTa models)
  • 🕵️‍♂️ Object detection / Instance segmentation
  • 🎯 Recommendation system

Example scripts for each task are provided in the examples/ directory. The core implementation of the GeN optimizer is in GeN/; GeN has roughly the same speed and memory cost as its base optimizer.

⚡ Quickstart

🛠️ Installation

Install the package from PyPI:

pip install gen-optim

🏃 Minimal Training Loop

To use GeN in your PyTorch training loop, add two lines between loss.backward() and optimizer.step():

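import torch.nn.functional as F
from torch.optim import AdamW
from GeN import lr_parabola

# model, train_loader, lazy_freq and scale are assumed to be defined already
optimizer = AdamW(model.parameters(), lr=1e-4)
tr_iter = iter(train_loader)  # separate iterator passed to lr_parabola

# Standard training pipeline
for batch_idx, (batch, labels) in enumerate(train_loader):
    loss = F.cross_entropy(model(batch), labels)
    loss.backward()
    # GeN: refit the learning rate every lazy_freq steps (lazy update)
    if (batch_idx + 1) % lazy_freq == 0:
        lr_parabola(model, optimizer, tr_iter=tr_iter, task='image_cls', scale=scale)
    optimizer.step()
    optimizer.zero_grad()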
  • scale can be used to enable the horizon-aware learning rate (e.g., np.linspace(1,0,epochs+1)).
  • For efficiency, call lr_parabola infrequently (a lazy update) by setting lazy_freq >= 4.
  • Each task value selects a different forward pass inside lr_parabola; this can be customized for other tasks.
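The lazy-update condition from the loop above can be checked in isolation. The helper lazy_refit_steps below is hypothetical (it is not part of the GeN package); it simply lists the batch indices at which `(batch_idx + 1) % lazy_freq == 0` fires, so with lazy_freq=4 the learning-rate refit runs on roughly one step in four.

```python
def lazy_refit_steps(steps_per_epoch: int, lazy_freq: int) -> list[int]:
    """Hypothetical helper: batch indices at which the quickstart loop calls lr_parabola."""
    if lazy_freq < 1:
        raise ValueError("lazy_freq must be a positive integer")
    # Same condition as the quickstart loop: (batch_idx + 1) % lazy_freq == 0
    return [i for i in range(steps_per_epoch) if (i + 1) % lazy_freq == 0]


if __name__ == "__main__":
    steps = lazy_refit_steps(steps_per_epoch=10, lazy_freq=4)
    print(steps)            # [3, 7]
    print(len(steps) / 10)  # 0.2, i.e. the refit runs on 2 of 10 steps
```

Between refits the optimizer simply keeps the most recently fitted learning rate, which is why a larger lazy_freq lowers the overhead.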

🧩 Function Overview

The main function is lr_parabola, which adapts the learning rate by fitting a quadratic curve (a parabola) to the loss as a function of the learning rate, with minimal code changes and computational overhead. This makes optimization learning-rate-free and, like the Newton–Raphson method, exploits Hessian information.
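The parabola fit can be sketched in plain Python: evaluate the loss at three step sizes along the update direction, fit a quadratic in the learning rate, and take its vertex. This is a minimal sketch under stated assumptions: the probe step sizes, the toy quadratic loss, and the function names are illustrative and are not the choices made inside lr_parabola.

```python
def parabola_argmin(etas, losses):
    """Fit L(eta) ~ a*eta**2 + b*eta + c through three points; return -b/(2a)."""
    (x0, x1, x2), (y0, y1, y2) = etas, losses
    # Newton divided differences give the quadratic's coefficients exactly
    d01 = (y1 - y0) / (x1 - x0)
    d12 = (y2 - y1) / (x2 - x1)
    a = (d12 - d01) / (x2 - x0)
    b = d01 - a * (x0 + x1)
    if a <= 0:
        return None  # no minimum: curvature along this direction is not positive
    return -b / (2 * a)


def toy_loss(w, h):
    # Toy quadratic loss L(w) = 0.5 * sum_i h_i * w_i**2, whose Hessian is diag(h)
    return 0.5 * sum(hi * wi * wi for hi, wi in zip(h, w))


def loss_along(w, g, eta, h):
    # Loss after a step of size eta along -g
    return toy_loss([wi - eta * gi for wi, gi in zip(w, g)], h)


if __name__ == "__main__":
    h, w = [1.0, 3.0], [1.0, 2.0]
    g = [hi * wi for hi, wi in zip(h, w)]  # plain gradient of the toy loss
    etas = [0.0, 0.1, 0.2]                 # hypothetical probe step sizes
    eta_star = parabola_argmin(etas, [loss_along(w, g, e, h) for e in etas])
    print(eta_star)  # ~0.3394 (= 37/109)
```

Because the toy loss is exactly quadratic, three probes recover the true minimizer along the line; on a real network the fit is only a local approximation.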

Mathematically, we turn any base optimizer (e.g., SGD or AdamW) into the GeN optimizer by

$$w_{t+1} = w_t - \eta_t^* g_t, \qquad \eta_t^* = \frac{G_t^\top g_t}{g_t^\top H_t g_t}$$

where g_t is the stochastic preconditioned gradient (the base optimizer's update direction), G_t is the oracle gradient, and H_t is the oracle Hessian.
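The optimal step size can be motivated by a second-order Taylor expansion of the loss along the update direction. The derivation below is a standard sketch using the symbols just defined, not a transcription of the paper's proof.

```latex
\begin{aligned}
L(w_t - \eta g_t) &\approx L(w_t) - \eta\, G_t^\top g_t + \tfrac{\eta^2}{2}\, g_t^\top H_t g_t \\
\frac{\partial}{\partial \eta}:\quad 0 &= -G_t^\top g_t + \eta\, g_t^\top H_t g_t \\
\Longrightarrow\quad \eta_t^* &= \frac{G_t^\top g_t}{g_t^\top H_t g_t}
\qquad (\text{requires } g_t^\top H_t g_t > 0)
\end{aligned}
```

If g_t happens to be the exact Newton direction H_t^{-1} G_t, this gives η_t^* = 1, the unit Newton–Raphson step; for any other preconditioner it picks the step size that minimizes the quadratic model along g_t.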

To make GeN horizon-aware, analogous to cosine or linear learning-rate decay, we apply a hyperparameter-free one-to-zero decay (controlled by `scale`):

[Update rules: GeN learning rate scaled by a one-to-zero decay factor]
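How lr_parabola consumes `scale` internally is not spelled out here. The sketch below assumes that the fitted GeN learning rate at epoch e is multiplied by scale[e], with scale built like np.linspace(1,0,epochs+1) but in plain Python. Both helper names are hypothetical.

```python
def one_to_zero_scale(epochs: int) -> list[float]:
    """Plain-Python equivalent of np.linspace(1, 0, epochs + 1)."""
    if epochs < 1:
        raise ValueError("epochs must be at least 1")
    return [1.0 - e / epochs for e in range(epochs + 1)]


def horizon_aware_lr(eta_star: float, scale: list[float], epoch: int) -> float:
    # Assumption: the fitted GeN learning rate is multiplied by this epoch's decay factor
    return scale[epoch] * eta_star


if __name__ == "__main__":
    scale = one_to_zero_scale(epochs=4)
    print(scale)  # [1.0, 0.75, 0.5, 0.25, 0.0]
    print([horizon_aware_lr(0.02, scale, e) for e in range(5)])
```

The decay has no tunable hyperparameters: it only needs the training horizon (epochs), and it reaches zero at the final epoch, much like a linear decay schedule.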

📚 Citation

If you use GeN in your research, please cite:

@inproceedings{bu2024gradient,
  title={Gradient Descent with Generalized {N}ewton's Method},
  author={Bu, Zhiqi and Xu, Shiyun},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025}
}



Download files


Source Distribution

gen_optim-0.1.1.tar.gz (5.8 kB)

Uploaded Source

Built Distribution


gen_optim-0.1.1-py3-none-any.whl (6.1 kB)

Uploaded Python 3

File details

Details for the file gen_optim-0.1.1.tar.gz.

File metadata

  • Download URL: gen_optim-0.1.1.tar.gz
  • Upload date:
  • Size: 5.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.5

File hashes

Hashes for gen_optim-0.1.1.tar.gz
Algorithm Hash digest
SHA256 77d98963c1c6c52d47130bef616283155b7a55d401a174dbf8d8277ab0151ab3
MD5 accd99f5f39dd93a96a0392cce0b6acf
BLAKE2b-256 daac16c140e5ba2587d5de7999309d3b53649842c506cfe54783257b92ec22a6


File details

Details for the file gen_optim-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: gen_optim-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 6.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.5

File hashes

Hashes for gen_optim-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 5a61e6d194cf151ae504d5c1f609bb16ec816526ec11740c9d1bdc998456b861
MD5 30d16126e799f812283cab3d75a448c1
BLAKE2b-256 026202030d5e33c800bd7d07530d94afe8c2c6f7c1721b70e0463e52a9f838f2

