
8-bit Adafactor Optimizer with Fused CUDA Kernels



Adafactor 8-bit with Fused CUDA Kernels

Requires Python 3.10+. License: MIT.

An 8-bit Adafactor optimizer designed for memory-efficient large-scale model training.

It uses fused CUDA kernels and log-space block-wise quantization to reduce optimizer state memory while maintaining training stability, making it suitable for training large models such as LLMs and diffusion models.

Key Features

  • Log-Space Quantization: Maps the second moment (the running average of squared gradients, loosely called the variance) into log2 space before 8-bit quantization. These values have a long-tailed distribution, so the log mapping reduces the risk of small second-moment estimates being truncated to zero and improves overall training stability.
  • Fused CUDA Kernels: Combines dequantization, EMA updates, warp-shuffle reductions, and requantization into single kernels, and uses float4 vectorized memory access to make better use of memory bandwidth.
  • Zero CPU-GPU Sync: Avoids implicit synchronizations (e.g., device-to-host copies) in the control flow, so the CPU does not stall waiting on the GPU and the computation pipeline keeps running.
  • Cross-Platform JIT: Compiles the CUDA kernels Just-In-Time (JIT) on first use, which keeps setup straightforward on both Windows and Linux.
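The plain-Python sketch below illustrates the idea behind the first two features. It encodes a block of non-negative second-moment values as uint8 codes in log2 space. A single update step then dequantizes the block, applies the EMA, and requantizes it.

It is a minimal illustration under stated assumptions, not the package's kernel. The per-block [lo, hi] log range, the LOG_FLOOR constant, the rounding scheme, and all function names are hypothetical. The real code runs per block on the GPU with warp-level reductions. A linear absmax encoder is included only for contrast.

```python
import math

LOG_FLOOR = 1e-30  # hypothetical floor so log2() is defined for exact zeros


def quantize_log2_block(values):
    """Encode a block of non-negative floats as uint8 codes in log2 space."""
    logs = [math.log2(max(v, LOG_FLOOR)) for v in values]
    lo, hi = min(logs), max(logs)
    step = (hi - lo) / 255 if hi > lo else 1.0
    codes = [round((x - lo) / step) for x in logs]  # each code lies in 0..255
    return codes, lo, hi


def dequantize_log2_block(codes, lo, hi):
    """Invert quantize_log2_block: code -> 2 ** (lo + code * step)."""
    step = (hi - lo) / 255 if hi > lo else 1.0
    return [2.0 ** (lo + c * step) for c in codes]


def quantize_linear_absmax(values):
    """Linear block-wise encoding of non-negative values, shown for contrast."""
    absmax = max(values) or 1.0
    return [round(v / absmax * 255) for v in values], absmax


def dequantize_linear(codes, absmax):
    return [c / 255 * absmax for c in codes]


def ema_update_block(codes, lo, hi, grads, beta2=0.999):
    """Unfused version of the per-block work a fused kernel might do."""
    v = dequantize_log2_block(codes, lo, hi)                            # 1. dequantize
    v = [beta2 * vi + (1 - beta2) * g * g for vi, g in zip(v, grads)]  # 2. EMA of g^2
    return quantize_log2_block(v)                                       # 3. requantize


# A block whose entries span roughly eight orders of magnitude.
block = [1e-8, 3e-6, 2e-3, 0.5]
codes, lo, hi = quantize_log2_block(block)
log_approx = dequantize_log2_block(codes, lo, hi)
lin_codes, absmax = quantize_linear_absmax(block)
lin_approx = dequantize_linear(lin_codes, absmax)
```

With this block, the linear encoder rounds 1e-8 and 3e-6 to code 0, so both dequantize to exactly zero. The log2 encoding recovers every entry to within a few percent relative error.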

Performance

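  • Memory Footprint: Adafactor factorizes the second moment of an n × m weight matrix into n row statistics and m column statistics instead of storing n × m values. Combined with 8-bit quantization, this generally makes the optimizer state smaller than that of bitsandbytes' AdamW8bit.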
  • Training Speed: The fused kernel design and reduced synchronization overhead allow it to achieve step times comparable to other mainstream 8-bit optimizers.
  • Quantization Precision: Adafactor's second moment is always non-negative and spans several orders of magnitude. With linear block-wise quantization, the largest value in a block sets the scale, so a single outlier gradient can push small entries to zero. Mapping values to UINT8 in log2 space instead keeps roughly constant relative precision across the range, which preserves small second-moment estimates and mitigates that instability.
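To make the memory claim concrete, here is a plain-Python sketch of the standard Adafactor factorization using row and column means. Only a row vector R and a column vector C are kept, and the full second moment is reconstructed as V_hat[i][j] = R[i] * C[j] / mean(R).

The helper names are hypothetical. The README does not say whether this package stores R and C in 8-bit or in FP32.

```python
import math


def factored_ema(row, col, grad, beta2=0.999, eps=1e-30):
    """One Adafactor-style EMA update of row/column second-moment statistics.

    grad is an n x m list of lists; row has n entries and col has m entries.
    """
    n, m = len(grad), len(grad[0])
    sq = [[g * g + eps for g in r] for r in grad]
    row_mean = [math.fsum(r) / m for r in sq]
    col_mean = [math.fsum(sq[i][j] for i in range(n)) / n for j in range(m)]
    new_row = [beta2 * r + (1 - beta2) * x for r, x in zip(row, row_mean)]
    new_col = [beta2 * c + (1 - beta2) * x for c, x in zip(col, col_mean)]
    return new_row, new_col


def reconstruct(row, col):
    """Rank-1 estimate of the full second moment: R[i] * C[j] / mean(R)."""
    r_mean = math.fsum(row) / len(row)
    return [[r * c / r_mean for c in col] for r in row]


# Squared gradient that is exactly rank-1 (g_ij = a_i * b_j), so the estimate is exact.
a, b = [1.0, 2.0], [3.0, 1.0, 0.5]
grad = [[ai * bj for bj in b] for ai in a]
row, col = factored_ema([0.0] * len(a), [0.0] * len(b), grad, beta2=0.0)
v_hat = reconstruct(row, col)

# Stored statistics for a 4096 x 4096 weight: full vs factored second moment.
n, m = 4096, 4096
full_entries, factored_entries = n * m, n + m
```

For a 4096 × 4096 weight, this is 8,192 stored statistics instead of 16,777,216. For general gradients the reconstruction is only an approximation. It is exact here because the example's squared gradient is rank-1.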

Installation

This project uses JIT (Just-In-Time) compilation.

Please ensure that torch and ninja are installed. You also need the CUDA Toolkit (nvcc) and a host C++ compiler that it supports (MSVC on Windows, GCC on Linux).

If CUDA compilation fails, the optimizer will automatically fall back to the pure PyTorch implementation.
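The fallback behaviour can be sketched as a try/except around the JIT build. This assumes the package compiles its kernels with PyTorch's torch.utils.cpp_extension.load; the README does not name the mechanism.

The source path, the module name, the fused_step attribute, and reference_step are all hypothetical placeholders.

```python
import logging

log = logging.getLogger("adafactor8bit_sketch")


def load_fused_backend(sources, name="adafactor8bit_fused_sketch"):
    """Try to JIT-compile the CUDA sources; return None if anything goes wrong."""
    try:
        # Real PyTorch API; needs torch, ninja, nvcc and a host C++ compiler.
        from torch.utils.cpp_extension import load
        return load(name=name, sources=sources, verbose=False)
    except Exception as exc:  # ImportError, missing toolchain, missing files, compile errors
        log.warning("Fused CUDA kernels unavailable (%r); using PyTorch fallback.", exc)
        return None


def reference_step(state):
    """Hypothetical stand-in for the pure PyTorch update path."""
    return state


backend = load_fused_backend(["hypothetical_kernels.cu"])
# `fused_step` is a hypothetical function name exported by the compiled module.
step_fn = backend.fused_step if backend is not None else reference_step
```

Selecting the backend once, at load time, keeps the per-step code free of try/except blocks and capability checks.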

From PyPI

pip install -U adafactor8bit

From Source

pip install git+https://github.com/yanfeiwong/adafactor-8bit.git

Note: The first time you instantiate the optimizer (or run the example script), it triggers JIT compilation of the CUDA sources. This may take anywhere from a few seconds to a couple of minutes depending on your system, and the terminal may appear unresponsive in the meantime. The compiled binary is then cached, so subsequent runs load it instead of recompiling.

Usage Example

We recommend using param_groups to keep the optimizer state of sensitive parameters (embeddings, norms, biases) in FP32. Enable 8-bit quantization only for large 2D weight matrices.

import torch
import torch.nn as nn
from adafactor8bit import Adafactor8Bit

def get_param_groups(model, weight_decay=1e-2):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad: continue
        # Protect 1D tensors, biases, norms, and embeddings
        if param.ndim <= 1 or "bias" in name or "norm" in name or "embed" in name:
            no_decay.append(param)
        else:
            decay.append(param)
            
    return [
        {"params": decay, "weight_decay": weight_decay, "quantize": True},
        {"params": no_decay, "weight_decay": 0.0, "quantize": False}
    ]

model = MyModel().cuda()  # MyModel: placeholder for your own nn.Module
optimizer = Adafactor8Bit(
    get_param_groups(model), 
    lr=1e-3, 
    relative_step=False,
)

# Training loop...

For a complete example, please refer to basic_usage.py.
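Before training, it can help to check which parameters will actually be quantized. The plain-Python sketch below applies the same rule as get_param_groups to (name, shape) pairs. It also adds a hypothetical min_numel threshold (not an Adafactor8Bit option) that keeps small matrices in FP32, in line with the advice to quantize only large 2D weights.

The groups here hold names, so this is a dry run. The groups passed to the optimizer must hold the parameter tensors.

```python
from math import prod


def plan_param_groups(named_shapes, weight_decay=1e-2, min_numel=4096):
    """Split (name, shape) pairs into an 8-bit group and an FP32 group.

    Same substring rule as get_param_groups, plus a hypothetical size
    threshold: only 2D+ weights with at least `min_numel` elements are quantized.
    """
    sensitive = ("bias", "norm", "embed")
    quantized, full_precision = [], []
    for name, shape in named_shapes:
        too_small = prod(shape) < min_numel
        if len(shape) <= 1 or too_small or any(key in name for key in sensitive):
            full_precision.append(name)
        else:
            quantized.append(name)
    return [
        {"params": quantized, "weight_decay": weight_decay, "quantize": True},
        {"params": full_precision, "weight_decay": 0.0, "quantize": False},
    ]


example = [
    ("embed_tokens.weight", (32000, 512)),
    ("layers.0.attn.q_proj.weight", (512, 512)),
    ("layers.0.attn.q_proj.bias", (512,)),
    ("layers.0.norm.weight", (512,)),
    ("head.tiny.weight", (8, 8)),
]
groups = plan_param_groups(example)
```

With PyTorch, you could build the input as [(n, tuple(p.shape)) for n, p in model.named_parameters() if p.requires_grad].

Inspecting the result also exposes a pitfall of the substring rule. Inside nn.Sequential, an nn.Embedding weight is named like 0.weight, so it lands in the quantized group unless you match on module type instead.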

Acknowledgements

Thanks to the large language models Qwen and DeepSeek for valuable technical discussions and code reviews on low-level CUDA optimization, memory safety mechanisms, and the design of the cross-platform compilation pipeline.

Thanks to Tim Dettmers for the paper 8-bit Optimizers via Block-wise Quantization and for the bitsandbytes library, both of which inspired this project.

Thanks to the PyTorch team for providing the foundational Optimizer implementation and the C++ Extension toolchain.

License

The project is released under the MIT License.



Download files


Source Distribution

adafactor8bit-0.1.4.tar.gz (14.6 kB)

Uploaded Source

Built Distribution


adafactor8bit-0.1.4-py3-none-any.whl (12.6 kB)

Uploaded Python 3

File details

Details for the file adafactor8bit-0.1.4.tar.gz.

File metadata

  • Download URL: adafactor8bit-0.1.4.tar.gz
  • Size: 14.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.9

File hashes

Hashes for adafactor8bit-0.1.4.tar.gz
  • SHA256: 9694e1869cc774492003a6e29f68cbe2428503aadf0395c825452e0415406f5e
  • MD5: c4ea33d98aa9011b7aae6fb6757a4c32
  • BLAKE2b-256: 1e6afaacd374ec9f8ce318d08451bc45f36261c97ffe85de33b6bdd51bf9c5df


File details

Details for the file adafactor8bit-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: adafactor8bit-0.1.4-py3-none-any.whl
  • Size: 12.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.9

File hashes

Hashes for adafactor8bit-0.1.4-py3-none-any.whl
  • SHA256: a82f96b9603dd599d04044c25de08ef7e3641f21e2b5f3e40866ae85e1fef372
  • MD5: 569a08ca2bf11342adf35d8e0982a64e
  • BLAKE2b-256: 6919088db884d0c9e6092ab36c1e6ecdd85b0916867e322745ce06256f6c6844

