AdamW Optimizer for bfloat16

Project description

AdamW optimizer for bfloat16 in PyTorch

This is a version of the AdamW optimizer for PyTorch. In ViT training tests, it achieves the same results as training with the weights in float32 and operations in float32 or bfloat16 (autocast). Keeping your weights in bfloat16 roughly halves the memory they occupy. The optimizer uses stochastic rounding and a correction term to achieve this.
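To illustrate why stochastic rounding matters for bfloat16 weights, here is a minimal pure-Python sketch. It is not this package's implementation, and the helper names (`float32_bits`, `bits_to_float`, `round_nearest_bf16`, `round_stochastic_bf16`) are hypothetical. It assumes finite inputs well below the float32 maximum and uses the common trick of adding 16 random bits to the float32 bit pattern before truncating to the upper 16 bits.

```python
import random
import struct


def float32_bits(x: float) -> int:
    """Bit pattern of x after conversion to IEEE-754 float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_float(bits: int) -> float:
    """Interpret a 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def round_nearest_bf16(x: float) -> float:
    """Round to bfloat16 precision with round-to-nearest-even."""
    bits = float32_bits(x)
    # Add half of the discarded range (plus the tie-breaking bit), then truncate.
    bits += 0x7FFF + ((bits >> 16) & 1)
    return bits_to_float(bits & 0xFFFF0000)


def round_stochastic_bf16(x: float, rng: random.Random) -> float:
    """Round to bfloat16 precision, rounding up with probability
    proportional to the discarded low 16 bits."""
    bits = float32_bits(x)
    bits += rng.getrandbits(16)  # uniform integer in [0, 65535]
    return bits_to_float(bits & 0xFFFF0000)


rng = random.Random(0)
weight, update = 1.0, 0.001  # the update is smaller than bfloat16 spacing near 1.0

nearest = round_nearest_bf16(weight + update)
samples = [round_stochastic_bf16(weight + update, rng) for _ in range(20_000)]
mean = sum(samples) / len(samples)

print(nearest)  # 1.0: round-to-nearest discards the small update
print(mean)     # close to 1.001: the update survives on average
```

Near 1.0, adjacent bfloat16 values are 2^-7 apart, so an update of 0.001 is always lost under round-to-nearest. Stochastic rounding instead lands on 1.0 or 1.0078125 with probabilities that preserve the expected value, which is what lets many small optimizer steps accumulate.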

There is a small (~10-20%) performance hit depending on your hardware.

To use:

import torch
from adamw_bf16 import AdamWBF16

# Store the model weights in bfloat16
model = model.to(dtype=torch.bfloat16)
# Pass the bfloat16 parameters to the optimizer
optimizer = AdamWBF16(model.parameters(), ...)

# Train your model
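As a rough check on the memory claim, the sketch below estimates the storage needed for the weights alone after the cast to bfloat16 shown above. The function name and the parameter count are hypothetical, chosen only for illustration, and optimizer state, gradients, and activations are not included.

```python
# Bytes used to store one element of each dtype.
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2}


def weight_memory_bytes(num_parameters: int, dtype: str) -> int:
    """Memory needed to store the weights only, in bytes."""
    return num_parameters * BYTES_PER_ELEMENT[dtype]


num_parameters = 86_000_000  # hypothetical model size, for illustration only
fp32_bytes = weight_memory_bytes(num_parameters, "float32")
bf16_bytes = weight_memory_bytes(num_parameters, "bfloat16")

print(f"float32 weights:  {fp32_bytes / 2**20:.0f} MiB")
print(f"bfloat16 weights: {bf16_bytes / 2**20:.0f} MiB")
```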

This repository was created using code from the following two projects. Combining insights from both made it possible to match the results obtained with the model weights stored in float32.

Download files

Source Distribution

adamw_bf16-0.0.1.tar.gz (16.3 kB)

Uploaded Source

Built Distribution

adamw_bf16-0.0.1-py3-none-any.whl (16.7 kB)

Uploaded Python 3

File details

Details for the file adamw_bf16-0.0.1.tar.gz.

File metadata

  • Download URL: adamw_bf16-0.0.1.tar.gz
  • Size: 16.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.12

File hashes

Hashes for adamw_bf16-0.0.1.tar.gz
Algorithm Hash digest
SHA256 ca690f29d669f60cc16c56a02f1ec0aa7c480a0563d6492ec3d83469a493e97c
MD5 1b57b0264d6f338ba5a04f1f82820983
BLAKE2b-256 8b5389d51a5347af8746db26f21cda31db98411eac8b04d64d4e15670baed794

File details

Details for the file adamw_bf16-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: adamw_bf16-0.0.1-py3-none-any.whl
  • Size: 16.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.12

File hashes

Hashes for adamw_bf16-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 a185ef44d47c00b591de12359d561f53379e124208aae3f3bd22744e7dbb4d22
MD5 416a9b26c0f99c83d1325ce7173a3d11
BLAKE2b-256 1288f9efb9201e6458a2d53bc485872b35d1ba016ba1e104a43aa12f6634d318
