HeavyBall
[!IMPORTANT]
The SOAP implementation was broken until 0.9.0. Please upgrade to 0.9.0 or later.
A simple package of efficient optimizers
The goal is not to strive for completeness, full maintenance, or abstraction, but instead to provide a simple, largely static alternative to torch.optim with more and better optimizers.
Currently (2024-11-12, 0.11.0), the recommended optimizer is PrecondSchedulePaLMForeachSOAP.
Features
- Stochastic Rounding: FP32 convergence with BF16 parameters
- Inplace EMA: Same math, but less memory, less compute and higher stability
- Foreach: Fast multi-tensor application
- PaLM Beta2: Fast initial convergence, stable late convergence (see the sketch after this list)
- ScheduleFree: No learning rate schedule, but better convergence
- Preconditioner Schedule: Improved loss-per-step in early convergence, better step-per-second in late convergence (explained below)
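The PaLM beta2 schedule mentioned above grows beta2 with the step count, so the second-moment average adapts quickly early in training and becomes long-memory (and therefore stable) later. A minimal sketch of that behaviour, assuming the rule published in the PaLM paper (beta2 = 1 - t^-0.8) rather than quoting heavyball's source:

```python
# Sketch of a PaLM-style beta2 schedule: beta2_t = 1 - t ** -0.8 (rule from the
# PaLM paper; assumed here, not copied from heavyball's implementation).
def palm_beta2(step: int) -> float:
    return 1.0 - step ** -0.8

for step in (1, 10, 100, 1_000, 10_000):
    print(f"step {step:>6,d}: beta2 = {palm_beta2(step):.4f}")
# beta2 climbs from 0.0 at step 1 to ~0.9994 at step 10,000: short memory and
# fast adaptation early, long memory and stable averaging late.
```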
Getting started
```
pip install heavyball
```

```python
import torch
import heavyball

# Create a model
model = torch.nn.Linear(16, 1)

# Create an optimizer
optimizer = heavyball.PrecondSchedulePaLMForeachSOAP(model.parameters(), lr=1e-3)

x = torch.randn(128, 16)
y = torch.randn(128, 1)

for _ in range(1000):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```
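Because the optimizers implement stochastic rounding, parameters can be kept in bfloat16 while convergence stays close to full fp32 training. A sketch of what that setup might look like, assuming the same constructor as above works unchanged with bf16 parameters:

```python
import torch
import heavyball

# Hypothetical bf16 setup: parameters stored in bfloat16; the "Stochastic
# Rounding" feature above is what keeps updates at fp32-like quality.
model = torch.nn.Linear(16, 1).bfloat16()
optimizer = heavyball.PrecondSchedulePaLMForeachSOAP(model.parameters(), lr=1e-3)

x = torch.randn(128, 16, dtype=torch.bfloat16)
y = torch.randn(128, 1, dtype=torch.bfloat16)

for _ in range(1000):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```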
Optimizers
Name | Description | Advantages / Disadvantages |
---|---|---|
ForeachAdamW | More efficient (speed, memory) AdamW | + Faster than AdamW + Possibly more (numerically) stable |
ForeachLaProp | More efficient (speed, memory) LaProp | + Same cost as AdamW + Marginally better convergence (better proofs) + Higher hyperparameter stability - Not a guaranteed win (can be neutral) - No "Slingshot" |
ForeachADOPT | More efficient (speed, memory) ADOPT | + Same cost as AdamW + Rigorous mathematical convergence proofs, even for challenging models (GANs) - Empirically underperforms LaProp - no bf16 |
ForeachSFAdamW | More efficient (speed, memory) ScheduleFree AdamW | + Same cost as AdamW, but better eval perf + Full control over hyperparameters |
PaLMForeachSFAdamW | ForeachSFAdamW with PaLM's beta2 schedule | + Same cost as AdamW, but better eval perf + Less control, but faster early and more stable late convergence + ScheduleFree - slow early convergence |
ForeachSOAP | More efficient (speed, memory) SOAP | + Faster convergence (loss-at-step) + Full control over hyperparameters - more memory usage - more hyperparameters - higher overhead than AdamW (can be amortized; better loss-at-second) |
PaLMForeachSOAP | ForeachSOAP with PaLM's beta2 schedule | + Faster convergence (loss-at-step) + Less control, but faster early and more stable late convergence - more memory usage - more hyperparameters - higher overhead than AdamW (can be amortized; better loss-at-second) |
SFPaLMForeachSOAP | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step) + less memory usage than PaLMForeachSOAP (more than AdamW) - slower initial convergence than PaLMForeachSOAP (but allows higher LRs) - higher overhead than AdamW (can be amortized) |
PrecondScheduleSFPaLMForeachSOAP | SFPaLMForeachSOAP with preconditioner schedule, matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP + Significantly faster (sec/it) later + less memory usage than PaLMForeachSOAP (more than AdamW) - slower initial convergence than PaLMForeachSOAP (but allows higher LRs) - higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
PrecondSchedulePaLMForeachSOAP | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence + Significantly faster (sec/it) later + high stability - more memory usage than PrecondScheduleSFPaLMForeachSOAP - higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
PrecondScheduleForeachSOAP | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence + Significantly faster (sec/it) later - more memory usage than PrecondScheduleSFPaLMForeachSOAP - higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
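All of these classes are intended as largely drop-in, torch.optim-style optimizers, so comparing them on a workload should mostly be a one-line change. A sketch under that assumption (only the lr keyword from the example above is used; the helper function is hypothetical and other keyword arguments may differ per optimizer):

```python
import torch
import heavyball

def make_optimizer(model: torch.nn.Module, name: str = "soap"):
    # Hypothetical helper: pick one of the optimizers from the table above.
    params = model.parameters()
    if name == "adamw":
        return heavyball.ForeachAdamW(params, lr=1e-3)
    if name == "schedule_free":
        return heavyball.ForeachSFAdamW(params, lr=1e-3)
    return heavyball.PrecondSchedulePaLMForeachSOAP(params, lr=1e-3)
```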
Precond Schedule
The default preconditioner schedule (f) would yield the following update intervals:
Steps | Interval, f | Total (schedule) | Total (constant, every 2) | Total (constant, every 16) |
---|---|---|---|---|
10 | 1.00005 | 10 | 5 (0.5x) | 0 (0.0x) |
100 | 1.026 | 99 | 50 (0.5x) | 6 (0.1x) |
1,000 | 2.0 | 738 | 500 (0.7x) | 62 (0.1x) |
10,000 | 14.3 | 2,168 | 5,000 (2.3x) | 625 (0.3x) |
100,000 | 100.2 | 4,049 | 50,000 (12.3x) | 6,250 (1.5x) |
1,000,000 | 513 | 7,245 | 500,000 (69.0x) | 62,500 (8.6x) |
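The "Interval, f" column matches a closed form that grows with the logarithm of the step count. The sketch below reproduces the column with f(step) = (log10(step) / 3) ** 9 + 1, a formula inferred from the table rather than quoted from heavyball's source:

```python
import math

def precond_interval(step: int) -> float:
    # Inferred from the table above: f(step) = (log10(step) / 3) ** 9 + 1.
    return (math.log10(step) / 3) ** 9 + 1

for step in (10, 100, 1_000, 10_000, 100_000, 1_000_000):
    print(f"{step:>9,d} -> {precond_interval(step):.6g}")
# Matches the table: 1.00005, 1.026, 2.0, 14.3, 100.2, 513.
```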
Utils
To access heavyball.utils, you need to explicitly import heavyball.utils.
It has several handy functions:
- set_torch() sets PyTorch optimization settings (TF32, opt_einsum, benchmark, ...)
- compile_mode, a string passed as-is to torch.compile(mode=compile_mode) in all compiled heavyball calls
- zeroth_power_mode, a string determining whether to use QR, newtonschulz{iterations}, SVD, or eigh to approximate the eigenvectors. Eigh has the highest precision and cost
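A short usage sketch: set_torch() is the function named above, while treating compile_mode and zeroth_power_mode as writable module-level attributes (and the specific values assigned) is an assumption about how heavyball.utils exposes these settings:

```python
import heavyball.utils

# Named above: enables TF32, opt_einsum, benchmark mode, etc.
heavyball.utils.set_torch()

# Assumed to be module-level strings; the values below are illustrative only.
heavyball.utils.compile_mode = "max-autotune"  # forwarded to torch.compile(mode=...)
heavyball.utils.zeroth_power_mode = "qr"       # or newtonschulz{iterations}, svd, eigh (see list above)
```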