BFloat16 Fused Optimizer

A mixed-precision optimizer to solve the stale weights problem of bfloat16 training.

When training with a pure-bfloat16 optimizer, an update is often cancelled entirely when it is small in magnitude compared to the weight, leading to the stale weights problem, which significantly hurts performance.

Exploiting the fact that the round-towards-zero (RTZ) result of a float32-to-bfloat16 cast is simply the high 16 bits of the float32 value, this optimizer stores an extra 16 bits of weight mantissa and acts as a 16+16 optimizer. This is mathematically equivalent to keeping an extra 32-bit master weight, solving the stale weights problem while costing only 25% more memory.
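
To make the bit layout concrete, here is a minimal sketch (not part of the package) showing that truncating a float32 to its high 16 bits is exactly the RTZ bfloat16 value, and that keeping the low 16 bits as separate state recovers the float32 master weight exactly:

import torch

master = torch.randn(4, dtype=torch.float32)   # fp32 master weight
bits = master.view(torch.int32)                # raw 32-bit pattern

high = bits & -65536                           # upper 16 bits: sign, exponent, 7 mantissa bits
low = bits & 0xFFFF                            # lower 16 mantissa bits, kept as extra state

# The upper half alone is exactly the round-towards-zero bf16 value,
# so the bf16 param tensor doubles as the top half of the master weight.
param_rtz = high.view(torch.float32).to(torch.bfloat16)

# Recombining both halves reproduces the fp32 master weight bit for bit.
recombined = (high | low).view(torch.float32)
assert torch.equal(recombined, master)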

Usage

Drop-in replacement for torch.optim.AdamW. All parameters need to be in bfloat16.

  • Doesn't support the foreach or fused arguments, as the optimizer is already fused
  • Doesn't support the amsgrad, maximize, capturable, or differentiable arguments yet

pip install bf16_fused_adam

from bf16_fused_adam import BF16FusedAdamW

# All supported arguments are listed below
optim = BF16FusedAdamW(model.parameters(),
    lr=1e-3,
    weight_decay=0.1,
    betas=(0.9, 0.95),
    eps=1e-5,
)
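
For context, a minimal end-to-end training step might look like the sketch below. The toy model, data, and loss are illustrative only; the one real requirement, per the note above, is that every parameter is bfloat16 (and on a CUDA device, since the optimizer is a fused CUDA kernel).

import torch
import torch.nn as nn
from bf16_fused_adam import BF16FusedAdamW

# Hypothetical toy model; the only real requirement is bf16 parameters on GPU.
model = nn.Linear(128, 128).cuda().to(torch.bfloat16)
optim = BF16FusedAdamW(model.parameters(), lr=1e-3, weight_decay=0.1)

x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()   # toy loss, computed in fp32
loss.backward()                         # gradients are bf16, like the params
optim.step()
optim.zero_grad()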

Details

AdamW Reference States (PyTorch FusedAdamW):

  • param (bf16)
  • grad (bf16)
  • exp_avg (bf16)
  • exp_avg_sq (bf16)

16+16 Optimizer States (BF16FusedAdamW):

  • param (bf16, high 16 bits of master fp32 weights)
  • mantissa (uint16, low 16 bits of master fp32 weights)
  • grad (bf16)
  • exp_avg (bf16)
  • exp_avg_sq (bf16)
Master weight: (sign 1) (exponent 8) (mantissa 7) (mantissa 16)   = 32bit
               [             param 16           ] [mantissa 16]   = 32bit
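
For illustration, here is a non-fused reference sketch of a single 16+16 AdamW step. The function name and the signed int16 mantissa representation are illustrative assumptions (the optimizer stores uint16 mantissa state and fuses all of this into one CUDA kernel), but the arithmetic mirrors the split/recombine scheme above:

import torch

def adamw_step_16p16(param_bf16, mantissa_i16, grad_bf16,
                     exp_avg, exp_avg_sq, step,
                     lr=1e-3, betas=(0.9, 0.95), eps=1e-5, weight_decay=0.1):
    b1, b2 = betas

    # 1. Rebuild the fp32 master weight from the two 16-bit halves.
    high = param_bf16.view(torch.int16).to(torch.int32) << 16
    low = mantissa_i16.to(torch.int32) & 0xFFFF
    master = (high | low).view(torch.float32)

    # 2. Standard AdamW update, with the weight update applied in fp32.
    exp_avg.mul_(b1).add_(grad_bf16, alpha=1 - b1)
    exp_avg_sq.mul_(b2).addcmul_(grad_bf16, grad_bf16, value=1 - b2)
    denom = (exp_avg_sq.float() / (1 - b2 ** step)).sqrt_().add_(eps)
    master.mul_(1 - lr * weight_decay)
    master.addcdiv_(exp_avg.float() / (1 - b1 ** step), denom, value=-lr)

    # 3. Split the updated master weight back into bf16 param + 16-bit mantissa.
    bits = master.view(torch.int32)
    param_bf16.copy_((bits >> 16).to(torch.int16).view(torch.bfloat16))
    mantissa_i16.copy_(((bits << 16) >> 16).to(torch.int16))  # low 16 bits, sign-extended to fit int16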

TODO

  • Stochastic rounding (trading precision for memory; see the sketch after this list)
  • 16+8 optimizer (saving more memory)
Master weight: (sign 1) (exponent 8) (mantissa 7) (mantissa 8) (mantissa 8)   = 32bit
               [             param 16           ] [mantissa 8] [dropped 8]    = 24bit
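
As a rough illustration of the stochastic rounding item (a sketch of the general technique, not code from this package): instead of storing the low 16 bits, they are consumed as a probability of rounding the bf16 value up in magnitude, so no extra state needs to be kept at all.

import torch

def stochastic_round_to_bf16(x_fp32: torch.Tensor) -> torch.Tensor:
    # Add a random value in [0, 2^16) to the low 16 bits; whether the addition
    # carries into the upper half decides the rounding direction. Ignores inf/NaN.
    bits = x_fp32.view(torch.int32)
    noise = torch.randint(0, 1 << 16, bits.shape, dtype=torch.int32, device=bits.device)
    rounded = (bits + noise) & -65536            # keep only the upper 16 bits
    return rounded.view(torch.float32).to(torch.bfloat16)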

Consistency Tests

We test consistency against the reference AdamW implementation. To run the tests, clone this repository and run pytest:

pip install -e .
pytest
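
Conceptually, such a check runs BF16FusedAdamW on bf16 parameters alongside torch.optim.AdamW on fp32 copies of the same parameters and compares the results. The sketch below is illustrative, not the repository's actual test code, and the tolerances are loose since the fused optimizer keeps its moments in bf16:

import torch
from bf16_fused_adam import BF16FusedAdamW

torch.manual_seed(0)
p_bf16 = torch.randn(1024, device="cuda").to(torch.bfloat16).requires_grad_()
master = p_bf16.detach().float().requires_grad_()   # fp32 reference copy

kwargs = dict(lr=1e-3, betas=(0.9, 0.95), eps=1e-5, weight_decay=0.1)
test_opt = BF16FusedAdamW([p_bf16], **kwargs)
ref_opt = torch.optim.AdamW([master], **kwargs)

for _ in range(100):
    grad = torch.randn(1024, device="cuda")
    p_bf16.grad = grad.to(torch.bfloat16)
    master.grad = grad
    test_opt.step()
    ref_opt.step()

# Without stale weights, the bf16 params should track the fp32 reference closely.
torch.testing.assert_close(p_bf16.float(), master.detach(), atol=1e-2, rtol=1e-2)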

Passed

  • H100
  • A100
  • RTX 4090 [TBD]
  • RTX 3090 [TBD]

References

16+16 optimizer:

PyTorch AdamW:

Gopher:
