Efficient optimizers
HeavyBall
[!IMPORTANT]
It is recommended to use heavyball.utils.set_torch() for faster training and lower memory usage.
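A minimal sketch of applying this recommendation (assuming set_torch() can be called with no arguments, as described in the Utils section below):

```python
import heavyball.utils

# Apply heavyball's recommended PyTorch settings (TF32, opt_einsum, benchmark, ...).
heavyball.utils.set_torch()
```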
A simple package of efficient optimizers
The goal is not to strive for completeness, full maintenance, or abstraction, but instead to provide a simple, largely static alternative to torch.optim with more and better optimizers.
Currently (2024-11-20, 0.17.0), the recommended stable optimizer is PrecondSchedulePaLMSOAP (see below). The recommended experimental optimizer is DelayedPSGDKron (tuning guide).
Features
- Stochastic Rounding: FP32 convergence with BF16 parameters
- Inplace EMA: Same math, but less memory, less compute, and higher stability
- Foreach: Fast multi-tensor application (turn it off to save memory via foreach=False; see the sketch after this list)
- PaLM Beta2: Fast initial convergence, stable late convergence
- ScheduleFree: No learning rate schedule, yet better convergence
- Preconditioner Schedule: Improved loss-per-step in early convergence, better steps-per-second in late convergence (explained below)
- Memory-efficient storage: PSGD supports store_triu_as_line (default: True) to trade memory usage for memory bandwidth; turn it off for lower overhead (for more, see PSGD Efficiency)
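As a rough sketch of the foreach trade-off (foreach is assumed to be accepted as a constructor keyword, as the bullet above suggests):

```python
import torch
import heavyball

model = torch.nn.Linear(16, 1)

# foreach=True (the default) batches parameter updates for speed;
# foreach=False trades that speed for lower peak memory (assumed keyword, see above).
optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3, foreach=False)
```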
Getting started
```
pip install heavyball
```

```python
import torch
import heavyball

# Create a model
model = torch.nn.Linear(16, 1)

# Create an optimizer
optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)

x = torch.randn(128, 16)
y = torch.randn(128, 1)

for _ in range(1000):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```
Optimizers
Name | Description | Advantages / Disadvantages |
---|---|---|
AdamW | More efficient (speed, memory) AdamW | + Faster than AdamW + Possibly more (numerically) stable |
LaProp | More efficient (speed, memory) LaProp | + Same cost as AdamW + Marginally better convergence (better proofs) + Higher hyperparameter stability - Not a guaranteed win (can be neutral) - No "Slingshot" |
ADOPT | More efficient (speed, memory) ADOPT | + Same cost as AdamW + Rigorous mathematical convergence proofs, even for challenging models (GANs) - Empirically underperforms LaProp - No bf16 |
SFAdamW | More efficient (speed, memory) ScheduleFree AdamW | + Same cost as AdamW, but better eval perf + Full control over hyperparameters |
PaLMSFAdamW | ForeachSFAdamW with PaLM's beta2 schedule | + Same cost as AdamW, but better eval perf + Less control, but faster early and more stable late convergence + ScheduleFree - Slow early convergence |
SOAP | More efficient (speed, memory) SOAP | + Faster convergence (loss-at-step) + Full control over hyperparameters - More memory usage - More hyperparameters - Higher overhead than AdamW (can be amortized; better loss-at-second) |
PaLMSOAP | ForeachSOAP with PaLM's beta2 schedule | + Faster convergence (loss-at-step) + Less control, but faster early and more stable late convergence - More memory usage - More hyperparameters - Higher overhead than AdamW (can be amortized; better loss-at-second) |
SFPaLMSOAP | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step) + Less memory usage than PaLMForeachSOAP (more than AdamW) - Slower initial convergence than PaLMForeachSOAP (but allows higher LRs) - Higher overhead than AdamW (can be amortized) |
PrecondScheduleSFPaLMSOAP | SFPaLMForeachSOAP with a preconditioner schedule, matching the error of PrecondEvery=2 at the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP + Significantly faster (sec/it) later + Less memory usage than PaLMForeachSOAP (more than AdamW) - Slower initial convergence than PaLMForeachSOAP (but allows higher LRs) - Higher overhead than AdamW (can be amortized; goes to 0 with increasing number of steps) |
PrecondSchedulePaLMSOAP | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence + Significantly faster (sec/it) later + High stability - More memory usage than PrecondScheduleSFPaLMForeachSOAP - Higher overhead than AdamW (can be amortized; goes to 0 with increasing number of steps) |
PrecondScheduleSOAP | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence + Significantly faster (sec/it) later - More memory usage than PrecondScheduleSFPaLMForeachSOAP - Higher overhead than AdamW (can be amortized; goes to 0 with increasing number of steps) |
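All of these follow the constructor pattern from Getting started, so switching optimizers is a one-line change. A sketch, assuming each listed name is exported at the heavyball top level with the same params/lr interface as PrecondSchedulePaLMSOAP:

```python
import torch
import heavyball

model = torch.nn.Linear(16, 1)

# Pick any optimizer from the table (names other than PrecondSchedulePaLMSOAP
# are assumed to share the same constructor signature).
optimizer = heavyball.SFAdamW(model.parameters(), lr=1e-3)
# optimizer = heavyball.PaLMSOAP(model.parameters(), lr=1e-3)
# optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)
```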
Precond Schedule
The default preconditioner schedule (f) would yield the following update intervals:
Steps | Interval, f | Total (schedule) | Total (constant, every 2) | Total (constant, every 16) |
---|---|---|---|---|
10 | 1.00005 | 10 | 5 (0.5x) | 0 (0.0x) |
100 | 1.026 | 99 | 50 (0.5x) | 6 (0.1x) |
1,000 | 2.0 | 738 | 500 (0.7x) | 62 (0.1x) |
10,000 | 14.3 | 2,168 | 5,000 (2.3x) | 625 (0.3x) |
100,000 | 100.2 | 4,049 | 50,000 (12.3x) | 6,250 (1.5x) |
1,000,000 | 513 | 7,245 | 500,000 (69.0x) | 62,500 (8.6x) |
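The "Total (constant, ...)" columns can be reproduced by counting how often an update fires for a given interval; the "Total (schedule)" column comes from heavyball's internal f, which is not reproduced here. A sketch with a placeholder interval function:

```python
def precond_update_count(total_steps: int, interval_fn) -> int:
    """Count preconditioner updates when the gap between updates is interval_fn(step).

    interval_fn is a stand-in for heavyball's internal schedule f; any callable
    mapping a step index to an interval works for illustration.
    """
    updates = 0
    since_last = 0.0
    for step in range(1, total_steps + 1):
        since_last += 1.0
        if since_last >= interval_fn(step):
            updates += 1
            since_last = 0.0
    return updates


# Constant intervals reproduce the last two columns of the table above:
print(precond_update_count(1_000, lambda step: 2))   # 500
print(precond_update_count(1_000, lambda step: 16))  # 62
```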
Memory
Second-order optimizers make it difficult to estimate memory usage, as it depends on tensor shapes and hyperparameters. To estimate your memory usage, you may use test/test_memory.py, which also attempts to ensure there are no regressions.
Furthermore, you can find the real-world memory usage of a 300M-parameter video diffusion model below:
PSGD
HeavyBall offers various configurations of PSGD:
- "PSGDKron" is the baseline, equivalent to kron_torch, but with lower compute and memory overhead.
- "PurePSGD" has no momentum, further reducing memory and compute
- "DelayedPSGD" implements SOAP/ADOPT-style off-by-one momentum, which has worse initial convergence but higher stability
Utils
To access heavyball.utils, you need to explicitly import heavyball.utils.
It has several handy functions and settings (a usage sketch follows this list):
- set_torch() sets PyTorch optimization settings (TF32, opt_einsum, benchmark, ...)
- compile_mode, a string passed as-is to torch.compile(mode=compile_mode) in all compiled heavyball calls
- zeroth_power_mode, a string determining whether to use QR, newtonschulz{iterations}, svd, or eigh to approximate the eigenvectors; eigh has the highest precision and cost
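A sketch of adjusting these settings (compile_mode and zeroth_power_mode are assumed to be module-level attributes of heavyball.utils, read by the compiled heavyball calls):

```python
import heavyball.utils

# Recommended global PyTorch setup (TF32, opt_einsum, benchmark, ...)
heavyball.utils.set_torch()

# Assumed module-level knobs; values follow the descriptions above.
heavyball.utils.compile_mode = "max-autotune"  # forwarded to torch.compile(mode=...)
heavyball.utils.zeroth_power_mode = "qr"       # or "newtonschulz5", "svd", "eigh"
```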