BFloat16 Fused Adam Optimizer
A mixed-precision optimizer that solves the stale weights problem of bfloat16 training.
When training a model with a bfloat16 optimizer, an update is often cancelled when it is small in magnitude compared to the weight, leading to the stale weights problem, which significantly hurts performance.
Exploiting the fact that the round-towards-zero (RTZ) conversion of a float32 to bfloat16 keeps exactly its high 16 bits, this optimizer stores an extra 16-bit mantissa per weight, acting as a 16+16 optimizer. This is mathematically equivalent to storing an extra 32-bit master weight, so it solves the stale weights problem while costing only 25% more memory.
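The cancellation itself is easy to reproduce in plain PyTorch (independent of this package): a small update vanishes when the weight is kept in bfloat16, but survives in a float32 master weight.
import torch

w_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)
print(w_bf16 - 1e-4 == w_bf16)   # tensor(True): the update is rounded away ("stale weight")

w_fp32 = torch.tensor(1.0, dtype=torch.float32)
print(w_fp32 - 1e-4 == w_fp32)   # tensor(False): a float32 master weight keeps the update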
Usage
Drop-in replacement for torch.optim.AdamW. All parameters need to be in bfloat16.
- Doesn't support the foreach and fused arguments, as the optimizer is already fused
- Doesn't support the amsgrad, maximize, capturable, and differentiable arguments yet
pip install bf16_fused_adam
from bf16_fused_adam import BF16FusedAdamW
# All supported arguments are listed below
optim = BF16FusedAdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.1,
    betas=(0.9, 0.95),
    eps=1e-5,
)
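A hypothetical end-to-end sketch, assuming a CUDA device and a made-up model, showing the bfloat16 requirement in practice; the optimizer arguments are the documented ones above, the rest is only illustrative.
import torch
import torch.nn as nn
from bf16_fused_adam import BF16FusedAdamW

model = nn.Linear(1024, 1024).cuda().to(torch.bfloat16)   # every parameter must be bf16
optim = BF16FusedAdamW(model.parameters(), lr=1e-3, weight_decay=0.1)

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()   # toy loss, computed in fp32
loss.backward()                         # gradients are bf16, matching the parameters
optim.step()
optim.zero_grad()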
Details
AdamW Reference States (PyTorch FusedAdamW):
- param (bf16)
- grad (bf16)
- exp_avg (bf16)
- exp_avg_sq (bf16)
16+16 Optimizer States (BF16FusedAdamW):
- param (bf16, high 16 bits of master fp32 weights)
- mantissa (uint16, low 16 bits of master fp32 weights)
- grad (bf16)
- exp_avg (bf16)
- exp_avg_sq (bf16)
Master weight: (sign 1) (exponent 8) (mantissa 7) (mantissa 16) = 32 bits
               [          param 16             ] [ mantissa 16 ] = 32 bits
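The split can be modeled in plain PyTorch; the sketch below is only a conceptual picture of what the fused kernel maintains (the low half is held in an int32 tensor here for convenience, whereas the optimizer stores it as uint16).
import torch

def split_master(master):
    # fp32 master weight -> (bf16 param = high 16 bits, low 16 mantissa bits)
    bits = master.view(torch.int32)
    # param is exactly the round-towards-zero bfloat16 of master (-65536 == 0xFFFF0000)
    param = (bits & -65536).view(torch.float32).to(torch.bfloat16)
    mantissa = bits & 0xFFFF
    return param, mantissa

def join_master(param, mantissa):
    # (bf16 param, low 16 bits) -> bit-exact fp32 master weight
    high = param.to(torch.float32).view(torch.int32)  # low 16 bits are zero here
    return (high | mantissa).view(torch.float32)

master = torch.randn(8, dtype=torch.float32)
param, mantissa = split_master(master)
assert torch.equal(join_master(param, mantissa), master)  # exact round trip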
TODO
- Stochastic rounding (trading precision for memory; see the sketch below)
- 16+8 optimizer (saving more memory)
Master weight: (sign 1) (exponent 8) (mantissa 7) (mantissa 8) (mantissa 8) = 32 bits
               [          param 16             ] [ mantissa 8 ] [ dropped 8 ] = 24 bits
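For the stochastic rounding item, one common construction (not this package's code, just an illustration of the idea) adds random bits below the bfloat16 mantissa before truncating, so the rounded weight equals the fp32 value in expectation.
import torch

def fp32_to_bf16_stochastic(x):
    # Expects a contiguous float32 tensor. Add uniform noise in [0, 2^16) to the
    # raw bits, then truncate the low 16 bits (round towards zero in magnitude).
    bits = x.view(torch.int32)
    noise = torch.randint(0, 1 << 16, bits.shape, dtype=torch.int32, device=x.device)
    return ((bits + noise) & -65536).view(torch.float32).to(torch.bfloat16)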
Consistency Tests
We tested consistency against the reference AdamW implementation. To run the tests, clone this repository and run pytest:
pip install -e .
pytest
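A hedged sketch of what such a consistency check can look like, with made-up sizes and step counts; the actual tolerances and test structure live in the repository's pytest suite.
import torch
from bf16_fused_adam import BF16FusedAdamW

torch.manual_seed(0)
master = torch.randn(4096, device="cuda")                    # fp32 reference weights
ref = master.clone().requires_grad_()
param = master.to(torch.bfloat16).requires_grad_()           # bf16 weights under test

kwargs = dict(lr=1e-3, weight_decay=0.1, betas=(0.9, 0.95), eps=1e-5)
opt_ref = torch.optim.AdamW([ref], **kwargs)
opt_test = BF16FusedAdamW([param], **kwargs)

for _ in range(1000):
    g = torch.randn(4096, device="cuda")                     # shared synthetic gradient
    ref.grad = g
    param.grad = g.to(torch.bfloat16)
    opt_ref.step()
    opt_test.step()

print((param.float() - ref).abs().max().item())  # expected to stay near bf16 resolution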
Passed
- H100
- A100
- RTX 4090 [TBD]
- RTX 3090 [TBD]
References
16+16 optimizer:
PyTorch AdamW:
- https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/fused_adam_utils.cuh
- https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/FusedAdamWKernel.cu
Gopher:
File details
Details for the file bf16_fused_adam-0.1.tar.gz.
File metadata
- Download URL: bf16_fused_adam-0.1.tar.gz
- Upload date:
- Size: 13.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 393ef40b422dc0cb8002c57c9af08cdc5ad11607186119be95acd964b892885d
MD5 | 43790e87d6872f785098c3517689910f
BLAKE2b-256 | 626232b5d462a9af4ad59082160fb6bf815d0cbdaeed0a8546461cbadb30d8b1