
torchastic

A stochastic bfloat16-based optimizer library.

Stochastic Optimizer: Reduced Memory Consumption with BF16 Training

Key Feature: Drastically Lower Memory Requirements

The Stochastic Optimizer is designed to significantly reduce memory consumption by training entirely in BF16 (bfloat16) precision. Unlike traditional FP32 (32-bit floating point) or mixed-precision training, which can still require substantial memory, the Stochastic Optimizer shrinks the footprint of parameters, gradients, and optimizer states alike.

By leveraging BF16 for training, the Stochastic Optimizer reduces memory consumption by 50%, making it ideal for training larger models or increasing batch sizes within the same memory budget.

Memory Comparison: Adam, Adam8bit, and Stochastic Adam

Traditional Adam (FP32)

  • Parameter storage: 4 bytes per parameter
  • Gradient storage: 4 bytes per gradient
  • State 1 (momentum): 4 bytes per parameter
  • State 2 (variance): 4 bytes per parameter

Total memory required:
4 + 4 + 4 + 4 = 16 bytes per parameter

Adam8bit (8-bit Optimizer States)

  • Parameter storage: 4 bytes per parameter
  • Gradient storage: 4 bytes per gradient
  • State 1 (momentum): 1 byte per parameter (quantized)
  • State 2 (variance): 1 byte per parameter (quantized)

Total memory required:
4 + 4 + 1 + 1 = 10 bytes per parameter

Stochastic Adam (BF16)

In contrast, Stochastic Adam optimizes memory usage by storing everything in BF16:

  • Parameter storage: 2 bytes per parameter
  • Gradient storage: 2 bytes per gradient
  • State 1 (momentum): 2 bytes per parameter
  • State 2 (variance): 2 bytes per parameter

Total memory required:
2 + 2 + 2 + 2 = 8 bytes per parameter

Summary of Memory Savings

  • Traditional Adam: 16 bytes per parameter
  • Adam8bit: 10 bytes per parameter (37.5% reduction over Adam)
  • Stochastic Adam: 8 bytes per parameter (50% reduction over Adam, 20% reduction over Adam8bit)
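The byte counts above can be turned into a quick back-of-the-envelope calculator. The sketch below is illustrative only (the scheme names and the 1B-parameter example are ours, not part of torchastic), and it counts only weights, gradients, and optimizer states, not activations:

```python
# Bytes per parameter for each scheme: params + grads + momentum + variance
BYTES_PER_PARAM = {
    "Adam (FP32)":     4 + 4 + 4 + 4,  # 16
    "Adam8bit":        4 + 4 + 1 + 1,  # 10
    "Stochastic Adam": 2 + 2 + 2 + 2,  # 8
}

def optimizer_memory_gib(n_params: int, scheme: str) -> float:
    """Approximate training-state memory in GiB for a given scheme."""
    return n_params * BYTES_PER_PARAM[scheme] / 2**30

# Example: a hypothetical 1B-parameter model
for scheme in BYTES_PER_PARAM:
    print(f"{scheme}: {optimizer_memory_gib(1_000_000_000, scheme):.1f} GiB")
```

For a 1B-parameter model this works out to roughly 14.9 GiB for FP32 Adam, 9.3 GiB for Adam8bit, and 7.5 GiB for Stochastic Adam, matching the percentages above.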

With Stochastic Adam, you save an additional 20% of memory compared to Adam8bit, while maintaining the precision needed for training stability and accuracy.

Why BF16?

BF16 (bfloat16) is advantageous for deep learning because it provides the same dynamic range as FP32, while using fewer bits for the mantissa (7 bits in BF16 vs. 23 bits in FP32). This allows models to represent a wide range of values while using half the memory of FP32. However, the reduced precision of the mantissa can sometimes lead to stale gradients, especially during long accumulation phases like weight updates, where small updates can become too insignificant to register in BF16.
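The loss of small updates is easy to demonstrate. The sketch below emulates the FP32-to-BF16 cast in pure Python by keeping only the top 16 bits of the FP32 bit pattern (the `to_bf16` helper is ours, for illustration, not torchastic's API):

```python
import struct

def to_bf16(x: float) -> float:
    """Emulate rounding an FP32 value to bfloat16 (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # BF16 keeps the top 16 bits of the FP32 pattern; add a rounding bias
    # so ties go to even, then truncate the low 16 bits.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# Near 1.0, BF16's step size is 2**-7 ≈ 0.0078, so a 1e-4 update vanishes:
print(to_bf16(1.0 + 1e-4))  # 1.0 -- the update is lost entirely
```

A weight update of 1e-4 applied to a weight of 1.0 simply disappears: repeated over thousands of steps, that is exactly the stale-update problem described above.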

Stochastic BF16: Solving Stale Gradients

A key innovation in Stochastic BF16 is the use of stochastic rounding when casting from FP32 to BF16. Stochastic rounding ensures that even very small updates, which might be lost due to BF16's reduced precision, are probabilistically rounded up or down based on the FP32 value. This prevents the common issue of stale gradients where updates become too small to affect the model during long training accumulations. (Revisiting BFloat16 Training)


Thanks to @Nerogar for the fast stochastic rounding PyTorch implementation!

Benefits of Stochastic Rounding:

  • Prevents stale gradients: Small updates that could otherwise be lost are preserved, ensuring more accurate weight updates over time.
  • Improved training stability: Stochastic rounding is particularly useful during weight updates, where stable accumulation is critical to maintaining convergence and preventing stalling in training.
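Nerogar's fast PyTorch kernel is not reproduced here, but the underlying bit trick can be sketched in pure Python: add uniform random noise to the 16 low bits of the FP32 pattern that would be discarded, then truncate. The carry into the kept bits fires with probability equal to the truncated fraction, so the rounding is unbiased in expectation (the function name is illustrative, not torchastic's API):

```python
import random
import struct

def stochastic_round_bf16(x: float) -> float:
    """Round an FP32 value to bfloat16 with stochastic rounding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Noise in [0, 2**16) carries into the kept top 16 bits with probability
    # equal to the fraction being truncated away, so E[result] == x.
    bits = (bits + random.randrange(1 << 16)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# 1.0 + 2**-9 sits a quarter of the way between the BF16 neighbors
# 1.0 and 1.0078125, so about 25% of samples round up:
samples = [stochastic_round_bf16(1.0 + 2**-9) for _ in range(100_000)]
print(sum(samples) / len(samples))  # ≈ 1.001953125 on average
```

With plain round-to-nearest, every one of those samples would collapse to 1.0 and the 2**-9 update would be lost; with stochastic rounding it is preserved on average, which is precisely what keeps long accumulations from going stale.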

Conclusion

The Stochastic Optimizer provides a more memory-efficient alternative to both Adam and Adam8bit by training entirely in BF16 precision. It reduces the memory footprint by 50% compared to FP32 Adam, and by 20% compared to Adam8bit, while also preventing stale gradients through stochastic rounding. This makes the Stochastic Optimizer an excellent choice for scaling up models and improving efficiency in resource-constrained environments, all without sacrificing the quality of your training process.

How to Install

pip install torchastic

Build from Source

git clone https://github.com/lodestone-rock/torchastic/
cd torchastic
python setup.py sdist bdist_wheel
pip install .

How to Use

import torch
import torch.nn as nn
from torchastic import AdamW, StochasticAccumulator


class Model(nn.Module):
    ...


# Init model
model = Model(*model_args)
model.to(torch.bfloat16)
optimizer = AdamW(model.parameters(), lr=0.01, weight_decay=1e-2)

# Apply stochastic grad accumulator hooks
StochasticAccumulator.assign_hooks(model)

# Training
while True:

    # Gradient accumulation
    for _ in range(grad_accum_length):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            loss = model.loss(*model_input)
        loss.backward()

    # Apply grad buffer back
    StochasticAccumulator.reassign_grad_buffer(model)
    optimizer.step()
    optimizer.zero_grad()
