AdamW Optimizer for bfloat16
AdamW optimizer for bfloat16 in PyTorch
This is a version of the AdamW optimizer for PyTorch that, in ViT training tests, achieves the same results as training with the weights in float32 and operations in either float32 or bfloat16 (autocast). By keeping your weights in bfloat16, you can cut the memory they take up roughly in half. It achieves this using stochastic rounding and a correction term.
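The README does not spell out how adamw_bf16 implements stochastic rounding or its correction term, so the pure-Python sketch below only illustrates the general ideas under stated assumptions. bfloat16 values are simulated by keeping the top 16 bits of a float32. The "correction term" is modeled here as a Kahan-style compensation buffer that is itself stored in bfloat16, which may differ from what the package actually does. All helper names (`bf16_nearest`, `bf16_stochastic`, `simulate`) are hypothetical.

```python
import random
import struct


def f32_bits(x: float) -> int:
    """Return the float32 bit pattern of x (rounded to float32) as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(b: int) -> float:
    """Interpret the low 32 bits of b as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", b & 0xFFFFFFFF))[0]


def bf16_nearest(x: float) -> float:
    """Round x to bfloat16 with round-to-nearest-even (finite inputs only)."""
    b = f32_bits(x)
    lsb = (b >> 16) & 1  # last kept mantissa bit, used to break ties to even
    return bits_f32((b + 0x7FFF + lsb) & 0xFFFF0000)


def bf16_stochastic(x: float, rng: random.Random) -> float:
    """Round x to bfloat16 stochastically: add 16 random bits below the kept
    mantissa, then truncate. The magnitude is rounded up with probability
    equal to the discarded fraction of a bfloat16 step, so the result equals
    x on average."""
    b = f32_bits(x) + rng.getrandbits(16)
    return bits_f32(b & 0xFFFF0000)


def simulate(steps: int = 1000, update: float = 1e-3, seed: int = 0):
    """Repeatedly add a small update to a weight of 1.0 stored in bfloat16."""
    rng = random.Random(seed)
    exact = 1.0
    w_nearest = w_stoch = w_comp = bf16_nearest(1.0)
    comp = 0.0  # hypothetical correction term, kept in bfloat16
    for _ in range(steps):
        exact += update
        # Plain round-to-nearest: updates below half a bf16 step are lost.
        w_nearest = bf16_nearest(w_nearest + update)
        # Stochastic rounding: unbiased, but noisy.
        w_stoch = bf16_stochastic(w_stoch + update, rng)
        # Compensated update: carry the part of the update the weight could not absorb.
        y = update + comp
        new_w = bf16_nearest(w_comp + y)
        comp = bf16_nearest(y - (new_w - w_comp))
        w_comp = new_w
    return exact, w_nearest, w_stoch, w_comp


if __name__ == "__main__":
    exact, nearest, stochastic, compensated = simulate()
    print(f"exact={exact:.4f} nearest={nearest:.4f} "
          f"stochastic={stochastic:.4f} compensated={compensated:.4f}")
```

With these settings the round-to-nearest weight stays stuck at 1.0, because the bfloat16 spacing near 1.0 is 2^-7 ≈ 0.0078 and each 1e-3 update is below half of that. The stochastically rounded weight tracks the exact value on average, and the compensated weight stays within about one bfloat16 step of it.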
There is a small performance cost (roughly 10-20%), depending on your hardware.
To use:
```python
import torch
from adamw_bf16 import AdamWBF16

model = model.to(dtype=torch.bfloat16)
optimizer = AdamWBF16(model.parameters(), ...)
# Train your model
```
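As a back-of-the-envelope check of the memory claim: float32 stores each parameter in 4 bytes and bfloat16 in 2, so converting the weights with `model.to(dtype=torch.bfloat16)` halves the memory the weights themselves take up. The sketch below assumes a hypothetical 100M-parameter model and counts weights only; gradients, optimizer state (in whatever dtype AdamWBF16 keeps it) and activations are not included. On a real model, `sum(p.numel() * p.element_size() for p in model.parameters())` gives the actual byte count.

```python
# Bytes per parameter for each storage dtype.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}


def weight_memory_mib(num_params: int, dtype: str) -> float:
    """Memory taken by the weights alone, in MiB (no gradients or optimizer state)."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**20


n = 100_000_000  # hypothetical 100M-parameter model
fp32 = weight_memory_mib(n, "float32")
bf16 = weight_memory_mib(n, "bfloat16")
print(f"float32: {fp32:.1f} MiB, bfloat16: {bf16:.1f} MiB, saved: {fp32 - bf16:.1f} MiB")
# float32: 381.5 MiB, bfloat16: 190.7 MiB, saved: 190.7 MiB
```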
This repository was created using code from the following two projects. Combining insights from both was found to match the performance of training with the model weights stored in float32.
File details
Details for the file adamw_bf16-0.0.1.tar.gz.
File metadata
- Download URL: adamw_bf16-0.0.1.tar.gz
- Upload date:
- Size: 16.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | ca690f29d669f60cc16c56a02f1ec0aa7c480a0563d6492ec3d83469a493e97c
MD5 | 1b57b0264d6f338ba5a04f1f82820983
BLAKE2b-256 | 8b5389d51a5347af8746db26f21cda31db98411eac8b04d64d4e15670baed794
File details
Details for the file adamw_bf16-0.0.1-py3-none-any.whl.
File metadata
- Download URL: adamw_bf16-0.0.1-py3-none-any.whl
- Upload date:
- Size: 16.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | a185ef44d47c00b591de12359d561f53379e124208aae3f3bd22744e7dbb4d22
MD5 | 416a9b26c0f99c83d1325ce7173a3d11
BLAKE2b-256 | 1288f9efb9201e6458a2d53bc485872b35d1ba016ba1e104a43aa12f6634d318