
Efficient optimizers


HeavyBall

A simple package of efficient optimizers

The goal is not to strive for completeness, full maintenance, or abstraction, but to provide a simple, largely static alternative to torch.optim with more and better optimizers.

Currently (2024-11-08, 0.7.2), the recommended optimizer is PrecondSchedulePaLMForeachSOAP.

Features

  • Stochastic Rounding: FP32 convergence with BF16 parameters
  • Inplace EMA: Same math, but less memory, less compute and higher stability
  • Foreach: Fast multi-tensor application
  • PaLM Beta2: Fast initial convergence, stable late convergence
  • ScheduleFree: No learning rate schedule, but better convergence
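Stochastic rounding is what makes "FP32 convergence with BF16 parameters" possible: bf16 keeps only the top 16 bits of a float32, so small updates that plain truncation would discard still survive on average. The sketch below shows the idea on a single Python float using `struct`, for clarity only. It is not heavyball's implementation, and the helper names are hypothetical; a tensor version would apply the same bit operations elementwise.

```python
import random
import struct


def float32_bits(x: float) -> int:
    """Bit pattern of x rounded to IEEE-754 float32, as an unsigned 32-bit int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_float32(bits: int) -> float:
    """Interpret an unsigned 32-bit int as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def stochastic_round_bf16(x: float, rng: random.Random) -> float:
    """Stochastically round a float32 value onto the bf16 grid (hypothetical helper).

    Adding uniform noise to the 16 low bits that bf16 discards, then truncating,
    rounds up with probability equal to the discarded fraction, so the expected
    result equals x for finite values.
    """
    bits = float32_bits(x)
    bits = (bits + rng.getrandbits(16)) & 0xFFFF0000  # add noise, drop the low 16 bits
    return bits_to_float32(bits)


if __name__ == "__main__":
    rng = random.Random(0)
    x = 1.0 + 2.0 ** -9  # a quarter of the way from 1.0 to the next bf16 value
    samples = [stochastic_round_bf16(x, rng) for _ in range(100_000)]
    print(sum(samples) / len(samples))  # close to x; truncation would always give 1.0
```

Values already representable in bf16 (such as 1.5) pass through unchanged, because their low 16 bits are zero and the added noise never carries into the kept bits.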

Getting started

pip install heavyball

import torch
import heavyball

# Create a model
model = torch.nn.Linear(16, 1)

# Create an optimizer
optimizer = heavyball.PrecondSchedulePaLMForeachSOAP(model.parameters(), lr=1e-3)

x = torch.randn(128, 16)
y = torch.randn(128, 1)

for _ in range(1000):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

Optimizers

| Name | Description | Advantages / Disadvantages |
| --- | --- | --- |
| ForeachAdamW | More efficient (speed, memory) AdamW | + Faster than AdamW; + Possibly more (numerically) stable |
| ForeachLaProp | More efficient (speed, memory) LaProp | + Same cost as AdamW; + Marginally better convergence (better proofs); + Higher hyperparameter stability; - Not a guaranteed win (can be neutral); - No "Slingshot" |
| ForeachADOPT | More efficient (speed, memory) ADOPT | + Same cost as AdamW; + Rigorous mathematical convergence proofs, even for challenging models (GANs); - Empirically underperforms LaProp; - No bf16 |
| ForeachSFAdamW | More efficient (speed, memory) ScheduleFree AdamW | + Same cost as AdamW, but better eval performance; + Full control over hyperparameters |
| PaLMForeachSFAdamW | ForeachSFAdamW with PaLM's beta2 schedule | + Same cost as AdamW, but better eval performance; + Less control, but faster early and more stable late convergence; + ScheduleFree; - Slow early convergence |
| ForeachSOAP | More efficient (speed, memory) SOAP | + Faster convergence (loss-at-step); + Full control over hyperparameters; - More memory usage; - More hyperparameters; - Higher overhead than AdamW (can be amortized; better loss-at-second) |
| PaLMForeachSOAP | ForeachSOAP with PaLM's beta2 schedule | + Faster convergence (loss-at-step); + Less control, but faster early and more stable late convergence; - More memory usage; - More hyperparameters; - Higher overhead than AdamW (can be amortized; better loss-at-second) |
| SFPaLMForeachSOAP | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step); + Less memory usage than PaLMForeachSOAP (more than AdamW); - Slower initial convergence than PaLMForeachSOAP (but allows higher LRs); - Higher overhead than AdamW (can be amortized) |
| PrecondScheduleSFPaLMForeachSOAP | SFPaLMForeachSOAP with a preconditioner schedule, matching the error of PrecondEvery=2 at the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP; + Significantly faster (sec/it) later; + Less memory usage than PaLMForeachSOAP (more than AdamW); - Slower initial convergence than PaLMForeachSOAP (but allows higher LRs); - Higher overhead than AdamW (can be amortized; goes to 0 as the number of steps increases) |
| PrecondSchedulePaLMForeachSOAP | PrecondScheduleSFPaLMForeachSOAP without ScheduleFree | + Best initial convergence; + Significantly faster (sec/it) later; + High stability; - More memory usage than PrecondScheduleSFPaLMForeachSOAP; - Higher overhead than AdamW (can be amortized; goes to 0 as the number of steps increases) |
| PrecondScheduleForeachSOAP | PrecondSchedulePaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence; + Significantly faster (sec/it) later; - More memory usage than PrecondScheduleSFPaLMForeachSOAP; - Higher overhead than AdamW (can be amortized; goes to 0 as the number of steps increases) |
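Several optimizers above use PaLM's beta2 schedule. The PaLM paper sets the second-moment decay at step k to beta2 = 1 - k^(-0.8): it starts at 0 and approaches 1, so the averaging window 1/(1 - beta2) = k^0.8 grows with training. The sketch below assumes heavyball follows that formula with exponent 0.8; we have not checked its defaults or any clamping it may apply, and `palm_beta2` and `update_second_moment` are hypothetical names.

```python
def palm_beta2(step: int, decay: float = 0.8) -> float:
    """PaLM-style second-moment decay: beta2_k = 1 - k ** (-decay), with k >= 1."""
    if step < 1:
        raise ValueError("step counts from 1")
    return 1.0 - step ** -decay


def update_second_moment(v: list[float], grad: list[float], step: int) -> list[float]:
    """One EMA update of the squared-gradient estimate using the scheduled beta2."""
    beta2 = palm_beta2(step)
    return [beta2 * vi + (1.0 - beta2) * g * g for vi, g in zip(v, grad)]


if __name__ == "__main__":
    for k in (1, 10, 100, 1_000, 10_000):
        print(f"step {k:>6}: beta2 = {palm_beta2(k):.5f}")
```

At step 1, beta2 is 0 and the estimate is exactly the current squared gradient, with no stale history (fast early adaptation). Later, beta2 approaches 1 and the estimate averages over many steps (stable late convergence).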

Precond Schedule

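The default preconditioner schedule f yields the following update intervals and total numbers of preconditioner updates. For comparison, the table also lists updating at a constant interval of every 2 or every 16 steps; the ratios in parentheses are relative to the schedule's total.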

| Steps | Update interval f | Total updates (schedule) | Total updates (constant, every 2) | Total updates (constant, every 16) |
| --- | --- | --- | --- | --- |
| 10 | 1.00005 | 10 | 5 (0.5x) | 0 (0.0x) |
| 100 | 1.026 | 99 | 50 (0.5x) | 6 (0.1x) |
| 1,000 | 2.0 | 738 | 500 (0.7x) | 62 (0.1x) |
| 10,000 | 14.3 | 2,168 | 5,000 (2.3x) | 625 (0.3x) |
| 100,000 | 100.2 | 4,049 | 50,000 (12.3x) | 6,250 (1.5x) |
| 1,000,000 | 513 | 7,245 | 500,000 (69.0x) | 62,500 (8.6x) |
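The Interval column is consistent with f(t) = 1 + (log10(t) / 3)^9. This gives 1 + (1/3)^9 ≈ 1.00005 at 10 steps, 2 at 1,000 steps, and 1 + 2^9 = 513 at 1,000,000 steps. The sketch below is a reconstruction from the table, not heavyball's code. It assumes that step t updates the preconditioner with probability 1/f(t), so the expected total is the sum of 1/f(t), which approximately reproduces the "Total updates (schedule)" column. `precond_interval`, `expected_updates`, `scale`, and `power` are hypothetical names.

```python
import math


def precond_interval(step: int, scale: float = 3.0, power: float = 9.0) -> float:
    """Update interval f(t) = 1 + (log10(t) / scale) ** power (reconstructed, hypothetical)."""
    t = max(step, 1)  # guard against log10(0)
    return 1.0 + (math.log10(t) / scale) ** power


def expected_updates(num_steps: int) -> float:
    """Expected preconditioner updates if step t updates with probability 1 / f(t) (assumption)."""
    return sum(1.0 / precond_interval(t) for t in range(1, num_steps + 1))


if __name__ == "__main__":
    for steps in (10, 100, 1_000, 10_000):
        print(
            f"{steps:>7,} f={precond_interval(steps):9.5f} "
            f"schedule={expected_updates(steps):8.1f} "
            f"every2={steps // 2:>6,} every16={steps // 16:>5,}"
        )
```

Because the power is 9, f stays close to 1 for the first few hundred steps, so the preconditioner is refreshed almost every step early in training. The interval then grows quickly, which explains why the overhead relative to AdamW shrinks as training goes on.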
