
8-bit Adafactor Optimizer with Fused CUDA Kernels

Project description


Adafactor 8-bit with Fused CUDA Kernels

An 8-bit Adafactor optimizer designed for memory-efficient large-scale model training.

It uses fused CUDA kernels and block-wise quantization to reduce optimizer state memory while maintaining training stability, making it suitable for training large models such as LLMs and diffusion models.

Key Features

  • Fused CUDA Kernel: Combines dequantization, EMA updates, warp-shuffle reduction, and requantization in a single kernel, and uses float4 vectorized memory access to make full use of memory bandwidth.
  • Zero CPU-GPU Sync: The control flow is refactored to eliminate implicit CPU-GPU synchronizations, so the GPU computation pipeline runs asynchronously without stalling.
  • Cross-Platform JIT: Uses Just-In-Time (JIT) compilation, so setup works the same way on Windows and Linux.
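To make concrete what the fused kernel combines, the NumPy sketch below performs the same sequence of operations as separate, unfused steps on one flattened state tensor: dequantize the UINT8 state, apply the EMA update, reduce each block to a new scale, and requantize. It is an illustration written for this description, not the project's kernel. The function names are hypothetical, and it assumes a linear per-block mapping where each block stores one float32 scale (the block maximum, which is what the warp-shuffle reduction is assumed to compute). The project's actual code map and memory layout may differ.

```python
import numpy as np

def quantize_blockwise(x, block_size):
    """Hypothetical linear block-wise UINT8 quantization of a non-negative 1D array.

    Assumption: each block keeps one float32 scale (its maximum) and values are
    mapped linearly onto the codes 0..255.
    """
    n = x.size
    pad = (-n) % block_size
    blocks = np.pad(x.astype(np.float32), (0, pad)).reshape(-1, block_size)
    scales = blocks.max(axis=1)                  # per-block reduction -> new scale
    safe = np.where(scales > 0, scales, 1.0)     # avoid dividing by zero for all-zero blocks
    q = np.rint(blocks / safe[:, None] * 255.0).astype(np.uint8)
    return q, scales.astype(np.float32), n

def dequantize_blockwise(q, scales, n):
    """Map UINT8 codes back to float32 using each block's scale."""
    return (q.astype(np.float32) / 255.0 * scales[:, None]).reshape(-1)[:n]

def fused_ema_step(q, scales, n, grad, beta2, eps1, block_size):
    """Unfused reference for one pass: dequantize -> EMA -> block max -> requantize."""
    v = dequantize_blockwise(q, scales, n)
    # eps1 is added to grad^2 so the second-moment estimate stays strictly positive
    v = beta2 * v + (1.0 - beta2) * (grad.astype(np.float32) ** 2 + eps1)
    return quantize_blockwise(v, block_size)
```

A fused kernel does all of this per block in one read and one write of the quantized state; the unfused version above reads and writes full-precision intermediates between every step, which is the memory traffic the fusion avoids.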

Algorithm Details

The optimizer is rebuilt on top of the official PyTorch Adafactor, but its mathematical logic aligns more closely with the original paper and the HuggingFace transformers implementation. Key differences:

  1. Safe Injection of eps1: The official PyTorch implementation defaults to eps1=None and relies on clamp, which can produce NaNs when gradients are zero or extremely small. This project instead follows the original grad_squared + eps1 formulation, which keeps every second-moment entry strictly positive and prevents training crashes caused by rsqrt(0).
  2. Coupled Weight Decay: Unlike the official PyTorch implementation, which decouples weight decay from the RMS scaling, this project retains the coupled mechanism from the original paper (weight decay is multiplied by the effective learning rate, which includes the RMS scaling).
  3. Standard Parameter Support: Retains the core Adafactor switches such as relative_step and scale_parameter, so existing learning-rate scheduling strategies keep working.
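The three points above are easiest to see in a full-precision reference update. The following is a minimal NumPy sketch of a HuggingFace transformers-style Adafactor step without a first moment (beta1=None), written for this description rather than taken from this project; the defaults (eps2=1e-3, clip_threshold=1.0, decay_rate=-0.8) mirror HuggingFace's. It shows where eps1 enters, how relative_step and scale_parameter set the effective learning rate, and how the coupled weight decay reuses that learning rate. The function and state-dictionary names are hypothetical.

```python
import math
import numpy as np

def rms(x):
    return float(np.sqrt(np.mean(np.square(x))))

def adafactor_step(p, grad, state, lr=None, eps1=1e-30, eps2=1e-3,
                   clip_threshold=1.0, decay_rate=-0.8, weight_decay=0.0,
                   relative_step=True, scale_parameter=True):
    """One FP32 Adafactor step (no first moment), modeled on HuggingFace's version.

    Hypothetical reference only, not this project's kernel. Updates p in place.
    """
    if not relative_step and lr is None:
        raise ValueError("lr is required when relative_step=False")
    state["step"] = state.get("step", 0) + 1
    t = state["step"]
    # Effective learning rate: relative step size, optionally scaled by the parameter RMS
    step_size = min(1e-2, 1.0 / math.sqrt(t)) if relative_step else lr
    param_scale = max(eps2, rms(p)) if scale_parameter else 1.0
    eff_lr = param_scale * step_size
    beta2t = 1.0 - t ** decay_rate
    # eps1 is added to grad^2 before the EMA, so the second moment stays strictly positive
    sq = grad ** 2 + eps1
    if p.ndim == 2:
        # Factored second moment: one row statistic and one column statistic
        row = state.setdefault("row", np.zeros(p.shape[0], dtype=p.dtype))
        col = state.setdefault("col", np.zeros(p.shape[1], dtype=p.dtype))
        row[:] = beta2t * row + (1 - beta2t) * sq.mean(axis=1)
        col[:] = beta2t * col + (1 - beta2t) * sq.mean(axis=0)
        r = 1.0 / np.sqrt(row / row.mean())
        c = 1.0 / np.sqrt(col)
        update = r[:, None] * c[None, :] * grad
    else:
        v = state.setdefault("v", np.zeros_like(p))
        v[:] = beta2t * v + (1 - beta2t) * sq
        update = grad / np.sqrt(v)
    # Update clipping by RMS
    update /= max(1.0, rms(update) / clip_threshold)
    # Coupled weight decay: scaled by the same RMS-scaled effective learning rate
    if weight_decay != 0.0:
        p -= weight_decay * eff_lr * p
    p -= eff_lr * update
    return p
```

With an all-zero gradient, the eps1 term keeps every denominator finite, so the update is exactly zero instead of NaN.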

Performance

  • Memory Footprint: Optimizer-state memory is significantly lower than with AdamW8bit (bitsandbytes), making it well suited to training very large models or to memory-constrained setups.
  • Training Speed: The fused kernel and zero-sync design deliver step times comparable to mainstream 8-bit optimizers.
  • Quantization Precision & Stability: The second moment (variance) in Adafactor is always non-negative, so it is mapped to UINT8 (0 to 255). A signed INT8 mapping (-127 to 127) leaves half of its codes unused for non-negative data, so UINT8 gives roughly twice the effective resolution over the variance range.
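The UINT8-versus-INT8 argument can be checked numerically. The sketch below quantizes non-negative values with two simple linear maps, a signed absmax INT8 map and an unsigned UINT8 map, and compares the worst-case round-trip error. The linear maps and function names are assumptions for illustration only: real 8-bit optimizers such as bitsandbytes use non-linear dynamic quantization maps, and this project's exact code map is not described here, so treat this purely as a demonstration of the code-range argument.

```python
import numpy as np

def max_error_int8(x):
    """Signed absmax INT8 (illustrative): codes -127..127, but x >= 0 only uses 0..127."""
    scale = np.abs(x).max() / 127.0
    q = np.clip(np.rint(x / scale), -127, 127).astype(np.int8)
    return float(np.abs(q.astype(np.float64) * scale - x).max())

def max_error_uint8(x):
    """Unsigned UINT8 (illustrative): all 256 codes 0..255 cover [0, max(x)]."""
    scale = x.max() / 255.0
    q = np.clip(np.rint(x / scale), 0, 255).astype(np.uint8)
    return float(np.abs(q.astype(np.float64) * scale - x).max())

rng = np.random.default_rng(0)
variance = rng.random(100_000) ** 2  # non-negative stand-in for second-moment values
print(f"INT8  max error: {max_error_int8(variance):.2e}")
print(f"UINT8 max error: {max_error_uint8(variance):.2e}")
```

For a linear map the worst-case error is half a quantization step: max(x)/254 for the signed map versus max(x)/510 for the unsigned one, which is where the factor of about two comes from.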

Installation

This project uses JIT (Just-In-Time) compilation.

Please make sure torch and ninja are installed, and that the CUDA toolkit (nvcc) and a host C++ compiler (MSVC on Windows, GCC on Linux) are available in your environment.

If CUDA compilation fails, the optimizer will automatically fall back to the pure PyTorch implementation.
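Because a missing tool silently triggers the pure-PyTorch fallback, it can help to check the prerequisites before installing. The helper below is a hypothetical convenience script, not part of this package: it only looks for the tools on PATH and checks that torch is importable. PyTorch's extension builder can also locate nvcc through the CUDA_HOME environment variable, so a missing PATH entry for nvcc is not conclusive.

```python
import importlib.util
import shutil

def check_jit_prerequisites():
    """Report whether the tools needed to JIT-compile CUDA extensions are visible.

    Hypothetical helper: checks PATH only (nvcc may instead be found via CUDA_HOME).
    """
    # First host C++ compiler found on PATH: MSVC (cl) on Windows, g++/c++ on Linux
    host_cxx = next((c for c in ("cl", "g++", "c++") if shutil.which(c)), None)
    return {
        "torch": importlib.util.find_spec("torch") is not None,
        "ninja": shutil.which("ninja") is not None,
        "nvcc": shutil.which("nvcc") is not None,
        "host_cxx": host_cxx,
    }

if __name__ == "__main__":
    for tool, status in check_jit_prerequisites().items():
        print(f"{tool:>9}: {status}")
```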

pip install git+https://github.com/yanfeiwong/adafactor-8bit.git

Usage Example

It is recommended to use param_groups to keep the optimizer states of sensitive layers (Embedding, Norm, Bias) in FP32 and to enable 8-bit quantization only for large 2D weight matrices.

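import torch
import torch.nn as nn
from adafactor8bit import Adafactor8Bit

def get_param_groups(model, weight_decay=1e-2):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Protect 1D tensors, biases, norms, and embeddings
        if param.ndim <= 1 or "bias" in name or "norm" in name or "embed" in name:
            no_decay.append(param)
        else:
            decay.append(param)

    return [
        {"params": decay, "weight_decay": weight_decay, "quantize": True},
        {"params": no_decay, "weight_decay": 0.0, "quantize": False},
    ]

# Hypothetical placeholder model for illustration; substitute your own
class MyModel(nn.Module):
    def __init__(self, vocab=1000, dim=256):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)  # kept in FP32 ("embed" in name)
        self.proj = nn.Linear(dim, dim)        # 2D weight -> 8-bit states
        self.norm = nn.LayerNorm(dim)          # kept in FP32 ("norm" in name)

    def forward(self, x):
        return self.norm(self.proj(self.embed(x)))

model = MyModel().cuda()
optimizer = Adafactor8Bit(
    get_param_groups(model),
    lr=1e-3,
    relative_step=False,
    block_size=2048,
    min_8bit_size=4096
)

# Minimal training loop with a dummy objective
for _ in range(10):
    x = torch.randint(0, 1000, (8, 16), device="cuda")
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()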

For a complete example, please refer to basic_usage.py.

Acknowledgements

Thanks to the large language models Qwen and DeepSeek for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.

Thanks to Tim Dettmers for the inspiration from the paper "8-bit Optimizers via Block-wise Quantization" and the bitsandbytes library.

Thanks to the PyTorch team for providing the foundational Optimizer implementation and the C++ Extension toolchain.

License

The project is released under the MIT License.



Download files


Source Distribution

adafactor8bit-0.1.0.tar.gz (12.3 kB)

Uploaded Source

Built Distribution


adafactor8bit-0.1.0-py3-none-any.whl (10.5 kB)

Uploaded Python 3

File details

Details for the file adafactor8bit-0.1.0.tar.gz.

File metadata

  • Download URL: adafactor8bit-0.1.0.tar.gz
  • Upload date:
  • Size: 12.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.9

File hashes

Hashes for adafactor8bit-0.1.0.tar.gz
Algorithm Hash digest
SHA256 b5764370bca7ee156232207dd39d00ac1956f4209dfd6bb50394671de33d1984
MD5 6669652ed15bb4c30f8e5ca0ac210bc2
BLAKE2b-256 eb0c49a32e7f3e628d33f2c7014724b407ed600ecd30b6bf0bf2a15ad810d07a


File details

Details for the file adafactor8bit-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: adafactor8bit-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 10.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.9

File hashes

Hashes for adafactor8bit-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 6c14389c6fe706a4ad4209783d1ec912262fb5edaa8587f27c95dbb8c7d8779f
MD5 0713447dd18efa20c1c8a095fb018217
BLAKE2b-256 674c259bd9e8339228d520e3707ac1c8d188594c0233f7a2f94999bf501e8d49

