Efficient optimizers

HeavyBall

[!IMPORTANT]
It's recommended to use heavyball.utils.set_torch() for faster training and less memory usage.

A simple package of efficient optimizers

The goal is not to strive for completeness, full maintenance, or abstraction, but instead to provide a simple, largely static alternative to torch.optim with more and better optimizers.

Currently (2024-11-22, 0.19), the recommended stable optimizer is PrecondSchedulePaLMSOAP (see below). The recommended experimental optimizer is DelayedPSGDKron (tuning guide).

Features

  • Stochastic Rounding: FP32 convergence with BF16 parameters
  • Inplace EMA: Same math, but less memory, less compute and higher stability
  • Foreach: Fast multi-tensor application (turn it off to save memory via foreach=False)
  • PaLM Beta2: Fast initial convergence, stable late convergence
  • ScheduleFree: No learning rate schedule, but better convergence
  • Preconditioner Schedule: Improved loss-per-step in early convergence, better step-per-second in late convergence (explained below)
  • Memory-efficient storage: PSGD supports store_triu_as_line (default: True) and q_dtype to trade off memory usage for memory bandwidth; other optimizers have storage_dtype, supporting lower-precision EMAs at no(?) performance drop via stochastic rounding
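
To make the stochastic rounding feature above more concrete, here is a minimal pure-Python sketch of the general technique. It is not HeavyBall's implementation; the helper names (`float32_bits`, `bits_to_float`, `stochastic_round_to_bf16`) are made up for this illustration, and it assumes finite inputs. BF16 keeps the top 16 bits of a float32, so adding uniform noise to the 16 discarded bits before truncating rounds up or down with a probability that makes the result correct on average.

```python
import random
import struct


def float32_bits(x: float) -> int:
    """Return the IEEE-754 float32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_float(bits: int) -> float:
    """Interpret a 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def stochastic_round_to_bf16(x: float, rng: random.Random) -> float:
    """Randomly round a finite float32 value to a BF16-representable value.

    Uniform noise is added to the 16 low bits that BF16 drops, then those
    bits are cleared. The magnitude is rounded up with probability equal to
    the discarded fraction, so the rounding is unbiased in expectation.
    """
    noisy = (float32_bits(x) + rng.getrandbits(16)) & 0xFFFFFFFF
    return bits_to_float(noisy & 0xFFFF0000)


rng = random.Random(0)
x = 1.0 + 2.0 ** -10  # exact in float32, but lies between two BF16 values
samples = [stochastic_round_to_bf16(x, rng) for _ in range(20_000)]
mean = sum(samples) / len(samples)
print(sorted(set(samples)), mean)
```

Each sample is one of the two neighbouring BF16 values, while the average stays close to `x`. This is the property that lets small updates accumulate in low-precision storage instead of being rounded away every step.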

Getting started

pip install heavyball

import torch
import heavyball

# Create a model
model = torch.nn.Linear(16, 1)

# Create an optimizer
optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)

x = torch.randn(128, 16)
y = torch.randn(128, 1)

for _ in range(1000):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

Optimizers

| Name | Description | Advantages / Disadvantages |
|---|---|---|
| AdamW | More efficient (speed, memory) AdamW | + Faster than AdamW<br>+ Possibly more (numerically) stable |
| LaProp | More efficient (speed, memory) LaProp | + Same cost as AdamW<br>+ Marginally better convergence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
| ADOPT | More efficient (speed, memory) ADOPT | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
| SFAdamW | More efficient (speed, memory) ScheduleFree AdamW | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
| PaLMSFAdamW | ForeachSFAdamW with PaLM's beta2 schedule | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
| SOAP | More efficient (speed, memory) SOAP | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be amortized; better loss-at-second) |
| PaLMSOAP | ForeachSOAP with PaLM's beta2 schedule | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be amortized; better loss-at-second) |
| SFPaLMSOAP | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more than AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be amortized) |
| PrecondScheduleSFPaLMSOAP | SFPaLMForeachSOAP with preconditioner schedule, matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more than AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
| PrecondSchedulePaLMSOAP | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
| PrecondScheduleSOAP | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
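
Several optimizers above use "PaLM's beta2 schedule". As a rough sketch of the idea, the PaLM paper sets beta2 to 1 - k^(-0.8) at step k instead of using a constant. The function below illustrates that formula only; the name `palm_beta2` and the `beta2_scale` parameter are assumptions made for this example and are not taken from HeavyBall's API.

```python
def palm_beta2(step: int, beta2_scale: float = 0.8) -> float:
    """Return beta2 at a 1-indexed step, following 1 - step ** -beta2_scale.

    Early steps get a small beta2, so the second-moment estimate adapts
    quickly; later steps approach 1, which averages over a long history.
    """
    if step < 1:
        raise ValueError("step is 1-indexed and must be >= 1")
    return 1.0 - step ** -beta2_scale


for step in (1, 10, 100, 1_000, 10_000):
    print(step, round(palm_beta2(step), 4))
```

Under this formula beta2 starts at 0, reaches roughly 0.84 by step 10, and exceeds 0.99 after about a thousand steps, which is one way to read "fast initial convergence, stable late convergence".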

Precond Schedule

The default preconditioner schedule (f) would yield the following update intervals:

| Steps | Interval, f | Total (schedule) | Total (constant, every 2) | Total (constant, every 16) |
|---|---|---|---|---|
| 10 | 1.00005 | 10 | 5 (0.5x) | 0 (0.0x) |
| 100 | 1.026 | 99 | 50 (0.5x) | 6 (0.1x) |
| 1,000 | 2.0 | 738 | 500 (0.7x) | 62 (0.1x) |
| 10,000 | 14.3 | 2,168 | 5,000 (2.3x) | 625 (0.3x) |
| 100,000 | 100.2 | 4,049 | 50,000 (12.3x) | 6,250 (1.5x) |
| 1,000,000 | 513 | 7,245 | 500,000 (69.0x) | 62,500 (8.6x) |
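
The interval column above is consistent with f(step) = 1 + (log10(step) / 3)^9. That formula is inferred here from the listed values and is not quoted from the library, so treat it as an assumption. The sketch below also assumes that the preconditioner is updated at step s with probability 1 / f(s), which is one plausible way to arrive at a running total like the one in the "Total (schedule)" column. The function names are hypothetical.

```python
import math


def interval(step: int) -> float:
    """Assumed schedule: f(step) = 1 + (log10(step) / 3) ** 9."""
    return 1.0 + (math.log10(max(step, 1)) / 3.0) ** 9


def expected_updates(total_steps: int) -> float:
    """Expected update count if step s updates with probability 1 / f(s)."""
    return sum(1.0 / interval(s) for s in range(1, total_steps + 1))


for steps in (10, 100, 1_000, 10_000):
    print(steps, round(interval(steps), 5), round(expected_updates(steps)))
```

If the assumptions hold, the printed intervals should match the "Interval, f" column and the expected totals should land close to "Total (schedule)". The schedule updates almost every step early on, then thins out to roughly one update per 513 steps after a million steps.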

Memory

Second-order optimizers make it difficult to estimate memory usage, as it depends on parameter shapes and hyperparameters. To estimate your memory usage, you may use test/test_memory.py, which attempts to ensure there are no regressions.
Furthermore, you can find the real-world memory usage of a 300M-parameter video diffusion model below: img.png
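
To illustrate why shapes and hyperparameters matter, here is a back-of-the-envelope count of optimizer state for a single 2-D weight under a SOAP-style optimizer. It assumes the state layout described in the SOAP algorithm (two Adam moments, plus one statistics matrix and one eigenvector basis per preconditioned dimension) and a hypothetical `max_precond_dim` cutoff above which a dimension is not preconditioned. It is not a measurement of HeavyBall's actual allocations; use test/test_memory.py for that.

```python
from typing import Optional


def adamw_state_elements(m: int, n: int) -> int:
    """Two moment buffers, each the same shape as the (m, n) weight."""
    return 2 * m * n


def soap_state_elements(m: int, n: int, max_precond_dim: Optional[int] = None) -> int:
    """Rough element count of SOAP-style state for an (m, n) weight.

    Assumption: every preconditioned dimension d stores a d x d statistics
    matrix and a d x d eigenvector basis; dimensions larger than the
    (hypothetical) max_precond_dim cutoff are skipped.
    """
    total = adamw_state_elements(m, n)
    for dim in (m, n):
        if max_precond_dim is None or dim <= max_precond_dim:
            total += 2 * dim * dim
    return total


for shape, cap in (((4096, 4096), None), ((50_000, 512), 2048), ((50_000, 512), None)):
    ratio = soap_state_elements(*shape, cap) / adamw_state_elements(*shape)
    print(shape, cap, f"{ratio:.2f}x AdamW state")
```

Under these assumptions a square layer needs about three times the AdamW state, while a tall embedding-like matrix ranges from almost no overhead (large dimension skipped) to far more than the weight itself (large dimension preconditioned).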

PSGD

HeavyBall offers various configurations of PSGD:

  • "PSGDKron" is the baseline, equivalent to kron_torch, but with lower compute and memory overhead.
  • "PurePSGD" has no momentum, further reducing memory and compute.
  • "DelayedPSGD" implements SOAP/ADOPT-style off-by-one momentum, which has worse initial convergence but higher stability.

img.png

Utils

To access heavyball.utils, you need to explicitly import heavyball.utils.
It has several handy functions:

  • set_torch() sets PyTorch optimization settings (TF32, opt_einsum, benchmark, ...)
  • compile_mode, a string passed as-is to torch.compile(mode=compile_mode) in all compiled heavyball calls
  • zeroth_power_mode, a string determining whether to use QR, newtonschulz{iterations}, SVD, or eigh to approximate the eigenvectors. eigh has the highest precision and cost.
