Efficient optimizers
Project description
HeavyBall
A simple package of efficient optimizers
The goal is not to strive for completeness, full maintenance, or abstraction, but instead to provide a simple, largely static alternative to torch.optim with more and better optimizers.
Currently (2024-11-08, 0.7.4), the recommended optimizer is PrecondSchedulePaLMForeachSOAP.
Features
- Stochastic Rounding: FP32 convergence with BF16 parameters (see the sketch after this list)
- Inplace EMA: Same math, but less memory, less compute and higher stability
- Foreach: Fast multi-tensor application
- PaLM Beta2: Fast initial convergence, stable late convergence
- ScheduleFree: No learning rate schedule, but better convergence
- Preconditioner Schedule: Improved loss-per-step in early convergence, better step-per-second in late convergence (explained below)
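Stochastic rounding is handled inside the optimizers; the snippet below is only a generic illustration of the idea, not heavyball's implementation (which may differ). An FP32 value is rounded to BF16 by randomizing the 16 low bits that BF16 discards before truncating, so the rounding error is zero in expectation and BF16 parameters can still accumulate small FP32-sized updates.

```python
import torch

def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    """Illustration of stochastic rounding (not heavyball's internal code).
    Adds uniform noise to the 16 mantissa bits that BF16 drops, then truncates,
    so the result is the lower or upper BF16 neighbour with probability
    proportional to proximity (inf/nan edge cases are ignored)."""
    bits = x.float().contiguous().view(torch.int32)  # reinterpret FP32 bit pattern
    noise = torch.randint_like(bits, 0, 1 << 16)     # random fill for the dropped bits
    rounded = (bits + noise) & -(1 << 16)            # keep only the top 16 bits
    return rounded.view(torch.float32).to(torch.bfloat16)
```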
Getting started
pip install heavyball
import torch
import heavyball
# Create a model
model = torch.nn.Linear(16, 1)
# Create an optimizer
optimizer = heavyball.PrecondSchedulePaLMForeachSOAP(model.parameters(), lr=1e-3)
x = torch.randn(128, 16)
y = torch.randn(128, 1)
for _ in range(1000):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
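The optimizers listed below are intended as drop-in alternatives to torch.optim, so switching is a one-line change to the same training loop. Constructor arguments beyond lr are not shown here and may differ per optimizer; check the class signatures before relying on them.

```python
# e.g. the plain AdamW variant instead of SOAP, reusing the loop above
optimizer = heavyball.ForeachAdamW(model.parameters(), lr=1e-3)
```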
Optimizers
Name | Description | Advantages / Disadvantages |
---|---|---|
ForeachAdamW | More efficient (speed, memory) AdamW | + Faster than AdamW + Possibly more (numerically) stable |
ForeachLaProp | More efficient (speed, memory) LaProp | + Same cost as AdamW + Marginally better convergence (better proofs) + Higher hyperparameter stability - Not a guaranteed win (can be neutral) - No "Slingshot" |
ForeachADOPT | More efficient (speed, memory) ADOPT | + Same cost as AdamW + Rigorous mathematical convergence proofs, even for challenging models (GANs) - Empirically underperforms LaProp - No bf16 support |
ForeachSFAdamW | More efficient (speed, memory) ScheduleFree AdamW | + Same cost as AdamW, but better eval perf + Full control over hyperparameters |
PaLMForeachSFAdamW | ForeachSFAdamW with PaLM's beta2 schedule | + Same cost as AdamW, but better eval perf + Less control, but faster early and more stable late convergence + ScheduleFree - slow early convergence |
ForeachSOAP | More efficient (speed, memory) SOAP | + Faster convergence (loss-at-step) + Full control over hyperparameters - more memory usage - more hyperparameters - higher overhead than AdamW (can be amortized; better loss-at-second) |
PaLMForeachSOAP | ForeachSOAP with PaLM's beta2 schedule | + Faster convergence (loss-at-step) + Less control, but faster early and more stable late convergence - more memory usage - more hyperparameters - higher overhead than AdamW (can be amortized; better loss-at-second) |
SFPaLMForeachSOAP | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step) + less memory usage than PaLMForeachSOAP (more than AdamW) - slower initial convergence than PaLMForeachSOAP (but allows higher LRs) - higher overhead than AdamW (can be amortized) |
PrecondScheduleSFPaLMForeachSOAP | SFPaLMForeachSOAP with preconditioner schedule, matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP + Significantly faster (sec/it) later + less memory usage than PaLMForeachSOAP (more than AdamW) - slower initial convergence than PaLMForeachSOAP (but allows higher LRs) - higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
PrecondSchedulePaLMForeachSOAP | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence + Significantly faster (sec/it) later + high stability - more memory usage than PrecondScheduleSFPaLMForeachSOAP - higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
PrecondScheduleForeachSOAP | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence + Significantly faster (sec/it) later - more memory usage than PrecondScheduleSFPaLMForeachSOAP - higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
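Several of the variants above reuse PaLM's beta2 schedule: instead of a fixed beta2, the PaLM paper anneals it as beta2_t = 1 − t^(−0.8), so the second-moment estimate adapts quickly early in training and becomes long-memory (stable) later. A minimal sketch of that schedule, not heavyball's internal code:

```python
def palm_beta2(step: int, power: float = 0.8) -> float:
    """PaLM-style beta2 schedule: roughly 0.84 at step 10, 0.975 at step 100,
    approaching 1.0 as training progresses."""
    return 1.0 - max(step, 1) ** -power
```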
Precond Schedule
The default preconditioner schedule (f) would yield the following update intervals:
Steps | Interval, f | Total (schedule) | Total (constant, every 2) | Total (constant, every 16) |
---|---|---|---|---|
10 | 1.00005 | 10 | 5 (0.5x) | 0 (0.0x) |
100 | 1.026 | 99 | 50 (0.5x) | 6 (0.1x) |
1,000 | 2.0 | 738 | 500 (0.7x) | 62 (0.1x) |
10,000 | 14.3 | 2,168 | 5,000 (2.3x) | 625 (0.3x) |
100,000 | 100.2 | 4,049 | 50,000 (12.3x) | 6,250 (1.5x) |
1,000,000 | 513 | 7,245 | 500,000 (69.0x) | 62,500 (8.6x) |
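The exact schedule is defined in heavyball's source; the function below is reconstructed from the table and reproduces the "Interval, f" column (for example f(1,000) = 2.0 and f(1,000,000) = 513). Treating 1/f as a per-step update probability and summing it approximately reproduces the "Total (schedule)" column; how heavyball actually triggers the updates is an implementation detail not shown here.

```python
import math

def precond_interval(step: int) -> float:
    """Reconstruction of the default preconditioner-update interval f(step):
    ~1 early (update nearly every step), growing polynomially in log10(step)
    so the amortized overhead vanishes late in training."""
    return (math.log10(max(step, 1)) / 3) ** 9 + 1

def expected_updates(n_steps: int) -> float:
    """Expected preconditioner updates after n_steps, assuming each step updates
    with probability 1 / f(step); roughly the "Total (schedule)" column above."""
    return sum(1.0 / precond_interval(s) for s in range(1, n_steps + 1))
```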
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
heavyball-0.8.1.tar.gz (14.8 kB)
Built Distribution
heavyball-0.8.1-py3-none-any.whl (29.0 kB)
File details
Details for the file heavyball-0.8.1.tar.gz.
File metadata
- Download URL: heavyball-0.8.1.tar.gz
- Upload date:
- Size: 14.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.7
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | f9aade00afb95e81bd6de2b05070e61da50eac37c403cb848b9145a5f4c75d29 |
MD5 | da119644cfe363b7115fa10a08667376 |
BLAKE2b-256 | 24562c452c363d5235f4cb600d5d702e36bae5e74708c54fc7115709f699cbd3 |
File details
Details for the file heavyball-0.8.1-py3-none-any.whl.
File metadata
- Download URL: heavyball-0.8.1-py3-none-any.whl
- Upload date:
- Size: 29.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.7
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 0f618ccac1b7f7d4bb8267302501abddd86d14782a31408ce035d7a6475474a9 |
MD5 | ba9cf26decd798f20df28dbd4f98b3e1 |
BLAKE2b-256 | 2fe150308c043731b5aa3740819b09562dca05ea2677ed98c3fdffe624730cab |