8-bit Adafactor Optimizer with Fused CUDA Kernels
Reason this release was yanked:
Refined N-D tensor factorization logic to strictly align with the original Adafactor paper for 3D/4D layers (e.g., Conv2d, MoE). v0.1.1 may exhibit suboptimal training dynamics for these layers. Please upgrade to adafactor8bit>=0.1.2 for improved stability.
Project description
Adafactor 8-bit with Fused CUDA Kernels
An 8-bit Adafactor optimizer designed for memory-efficient large-scale model training.
It uses fused CUDA kernels and block-wise quantization to reduce optimizer state memory while maintaining training stability, making it suitable for training large models such as LLMs and diffusion models.
Key Features
- Fused CUDA Kernel: Integrates dequantization, EMA updates, Warp-Shuffle reduction, and requantization into a single kernel, using `float4` vectorized memory access to maximize memory-bandwidth utilization.
- Zero CPU-GPU Sync: Refactored the control flow to eliminate implicit synchronizations, so the GPU computation pipeline runs asynchronously at full speed.
- Cross-Platform JIT: Utilizes JIT (Just-In-Time) automatic compilation for seamless setup across Windows and Linux environments.
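The fused-kernel feature can be made concrete with a pure-Python reference of what one fused step does to a single quantized block of a (non-factored) second-moment state. This is an illustrative sketch, not the project's kernel: the function name, the linear absmax UINT8 code, and the per-block layout are assumptions. On the GPU, all four steps run inside one kernel launch, with `float4` loads and a warp-shuffle max reduction.

```python
def fused_block_update(codes, absmax, grad, beta2, eps1=1e-30):
    """Reference for one quantized block: dequantize -> EMA -> reduce -> requantize.

    codes:  UINT8 codes (0..255) of the block's second-moment state.
    absmax: the block's scale, i.e. the largest state value when it was quantized.
    grad:   gradient values for the same elements. Requires beta2 < 1.
    Returns the new (codes, absmax) pair.
    """
    # 1. Dequantize the stored 8-bit state back to floats.
    state = [c / 255.0 * absmax for c in codes]
    # 2. EMA update of the second moment, with eps1 injected into grad^2.
    state = [beta2 * v + (1.0 - beta2) * (g * g + eps1) for v, g in zip(state, grad)]
    # 3. Block-wide max; on the GPU this would be a warp-shuffle reduction.
    new_absmax = max(state)
    # 4. Requantize over the full unsigned range 0..255.
    new_codes = [round(v / new_absmax * 255) for v in state]
    return new_codes, new_absmax


codes, absmax = fused_block_update([0, 0, 0, 0], 1.0, [1.0, 0.5, 0.0, 2.0], beta2=0.0)
print(codes, absmax)  # [64, 16, 0, 255] 4.0
```

Because `eps1` is added before the reduction, `new_absmax` stays strictly positive even for an all-zero gradient block, so the requantization step never divides by zero.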
Algorithm Details
This implementation is rebuilt on top of the official PyTorch Adafactor, but its mathematical logic aligns more closely with the original paper and the HuggingFace transformers implementation. Key differences include:
- Safe Injection of `eps1`: The official PyTorch implementation defaults to `eps1=None` and relies on `clamp`, which can lead to NaNs when it encounters zero or extremely small gradients. This project adopts the original `grad_squared + eps1` approach, which guarantees that the second-moment estimate is strictly positive and prevents training crashes caused by `rsqrt(0)`.
- Coupled Weight Decay: Unlike the official PyTorch implementation, which decouples weight decay from the RMS scaling, this project retains the coupled mechanism from the original paper (weight decay is multiplied by the effective learning rate, which includes the RMS scaling).
- Standard Parameter Support: Fully retains core Adafactor switches such as `relative_step` and `scale_parameter`, ensuring compatibility with existing learning-rate scheduling strategies.
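The points above can be illustrated with a pure-Python sketch of one factored Adafactor step on a small matrix. It follows the formulation used by the original paper and HuggingFace transformers (decay rate -0.8, `eps=(1e-30, 1e-3)`, clipping threshold 1.0) rather than this project's CUDA code, and the function name and signature are hypothetical.

```python
import math


def rms(values):
    """Root-mean-square of a flat list of floats."""
    return math.sqrt(sum(v * v for v in values) / len(values))


def adafactor_step(param, grad, row, col, step, lr=None, eps1=1e-30, eps2=1e-3,
                   decay_rate=-0.8, clip_threshold=1.0, weight_decay=0.0,
                   relative_step=True, scale_parameter=True):
    """Apply one factored Adafactor step in place to an m x n matrix.

    param/grad are lists of m rows; row (len m) and col (len n) hold the
    factored second-moment EMAs. Returns the effective learning rate.
    """
    m, n = len(param), len(param[0])

    # Effective LR: relative step size times max(eps2, RMS(param)).
    rel_step = min(1e-2, 1.0 / math.sqrt(step)) if relative_step else lr
    scale = max(eps2, rms([p for r in param for p in r])) if scale_parameter else 1.0
    eff_lr = scale * rel_step

    beta2t = 1.0 - step ** decay_rate
    # eps1 is added to grad^2 BEFORE the EMA, so row and col stay strictly positive.
    sq = [[g * g + eps1 for g in r] for r in grad]
    for i in range(m):
        row[i] = beta2t * row[i] + (1.0 - beta2t) * sum(sq[i]) / n
    for j in range(n):
        col[j] = beta2t * col[j] + (1.0 - beta2t) * sum(sq[i][j] for i in range(m)) / m

    # Rank-1 estimate V[i][j] = row[i] * col[j] / mean(row); update = G / sqrt(V).
    row_mean = sum(row) / m
    update = [[grad[i][j] / math.sqrt(row[i] * col[j] / row_mean) for j in range(n)]
              for i in range(m)]

    # Update clipping: divide by max(1, RMS(update) / clip_threshold).
    denom = max(1.0, rms([u for r in update for u in r]) / clip_threshold)

    for i in range(m):
        for j in range(n):
            # Coupled weight decay: uses eff_lr, which already includes RMS scaling.
            param[i][j] -= weight_decay * eff_lr * param[i][j]
            param[i][j] -= eff_lr * update[i][j] / denom
    return eff_lr
```

With an all-zero gradient and fresh state, `eps1` keeps `row` and `col` positive, so the update is exactly zero. Calling the same function with `eps1=0.0` raises `ZeroDivisionError` on the 0/0 in the rank-1 estimate, the pure-Python counterpart of the `rsqrt(0)` failure described above.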
Performance
- Memory Footprint: Optimizer-state memory is significantly lower than that of `AdamW8bit` (bitsandbytes), making it an ideal choice for training very large models or working under tight memory budgets.
- Training Speed: The fused kernel and zero-sync design deliver step times comparable to mainstream 8-bit optimizers.
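- Quantization Precision & Stability: The second moment (variance) in Adafactor is always non-negative, so it is mapped to `UINT8` (0 to 255). Traditional 8-bit optimizers map to signed `INT8` (-127 to 127), which spends the negative half of the code range on values that never occur; using the full unsigned range gives twice as many levels for non-negative values, and therefore higher effective quantization precision within the variance domain.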
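To put the memory claim in perspective, here is a back-of-the-envelope estimate for a single 4096 × 4096 weight matrix. It assumes that `AdamW8bit` stores two 8-bit states per element plus one fp32 scale per block of 2048 elements, and that Adafactor keeps only a factored second moment in fp32 with no first-moment buffer. The function names are illustrative, and the exact block size and any extra buffers kept by either library may differ.

```python
import math


def adamw8bit_state_bytes(m, n, block_size=2048):
    """Two 8-bit states per element, each with one fp32 absmax per block."""
    numel = m * n
    blocks = math.ceil(numel / block_size)
    return 2 * (numel * 1 + blocks * 4)


def adafactor_factored_state_bytes(m, n, bytes_per_value=4):
    """Row (m) and column (n) second-moment vectors; no first moment."""
    return (m + n) * bytes_per_value


m, n = 4096, 4096
print(adamw8bit_state_bytes(m, n))           # 33619968 (about 32 MiB)
print(adafactor_factored_state_bytes(m, n))  # 32768 (32 KiB)
```

The gap comes from the factorization: per-element state grows with m·n, while factored state grows with m + n.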
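The precision argument can be checked numerically. The sketch below applies the same linear absmax block-wise quantizer to non-negative data twice: once using all 256 UINT8 codes and once using only the 128 non-negative codes of a signed INT8 range. This is a simplified linear scheme chosen for illustration; bitsandbytes and this project may use different code maps, and the helper names are hypothetical.

```python
def quantize_blockwise(values, block_size, levels):
    """Linear absmax quantization with one scale per block.

    levels=255 uses the full UINT8 range 0..255; levels=127 uses only the
    non-negative half (0..127) of a signed INT8 range.
    """
    codes, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block) or 1.0  # guard all-zero blocks
        scales.append(absmax)
        codes.extend(round(v / absmax * levels) for v in block)
    return codes, scales


def dequantize_blockwise(codes, scales, block_size, levels):
    return [c / levels * scales[i // block_size] for i, c in enumerate(codes)]


def max_error(values, block_size, levels):
    codes, scales = quantize_blockwise(values, block_size, levels)
    restored = dequantize_blockwise(codes, scales, block_size, levels)
    return max(abs(a - b) for a, b in zip(values, restored))


variance = [i / 1000 for i in range(1, 1001)]  # non-negative, like a second moment
err_u8 = max_error(variance, block_size=250, levels=255)
err_i8 = max_error(variance, block_size=250, levels=127)
print(f"UINT8 max error: {err_u8:.5f}, INT8 max error: {err_i8:.5f}")
```

Rounding error is at most half a step, so the worst case per block is absmax/510 for the unsigned code versus absmax/254 for the non-negative half of INT8.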
Installation
This project uses JIT (Just-In-Time) compilation.
Please ensure `torch` and `ninja` are installed, and that the CUDA Toolkit compiler (`nvcc`) and a compatible host C++ compiler (MSVC on Windows, GCC on Linux) are available in your environment.
If CUDA compilation fails, the optimizer automatically falls back to a pure PyTorch implementation.
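The compile-or-fall-back behavior can be sketched as a single try/except around the build. This is an assumption about the general pattern, not this project's code: `load_backend` and `build_with_torch` are hypothetical helpers, and the extension name and `.cu` file are placeholders. Only `torch.utils.cpp_extension.load` is a real PyTorch API.

```python
import warnings


def load_backend(build_fused):
    """Return ("cuda", extension) if build_fused() succeeds, else ("pytorch", None).

    build_fused is a hypothetical zero-argument hook that compiles and loads
    the CUDA extension; it is not part of adafactor8bit's public API.
    """
    try:
        return "cuda", build_fused()
    except Exception as exc:  # missing nvcc/host compiler, build error, no GPU, ...
        warnings.warn(f"fused CUDA kernel unavailable ({exc}); using pure PyTorch path")
        return "pytorch", None


def build_with_torch():
    # torch.utils.cpp_extension.load is PyTorch's JIT entry point; it builds
    # with ninja and caches the compiled module. The name and source file
    # below are hypothetical placeholders.
    from torch.utils.cpp_extension import load
    return load(name="adafactor8bit_cuda", sources=["adafactor8bit_kernel.cu"], verbose=False)
```

In practice one would call `load_backend(build_with_torch)` once when the optimizer is constructed, so a missing toolchain degrades to the slower path instead of raising.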
From PyPI
```shell
pip install -U adafactor8bit
```
From Source
```shell
pip install git+https://github.com/yanfeiwong/adafactor-8bit.git
```
Note: The first time you instantiate the optimizer (or run the example script), it automatically triggers JIT compilation of the CUDA source code in the background. This can take anywhere from a few seconds to a couple of minutes depending on your system, and the terminal may appear unresponsive in the meantime. Please be patient. Once compiled, the binary is cached, so subsequent runs skip compilation and start immediately.
Usage Example
It is recommended to use param_groups to keep sensitive layers (Embedding, Norm, Bias) in FP32, enabling 8-bit quantization only for large 2D weight matrices.
```python
import torch
import torch.nn as nn
from adafactor8bit import Adafactor8Bit


def get_param_groups(model, weight_decay=1e-2):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Protect 1D tensors, biases, norms, and embeddings (FP32, no weight decay)
        if param.ndim <= 1 or "bias" in name or "norm" in name or "embed" in name:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay, "quantize": True},
        {"params": no_decay, "weight_decay": 0.0, "quantize": False},
    ]


model = MyModel().cuda()  # MyModel: your own nn.Module
optimizer = Adafactor8Bit(
    get_param_groups(model),
    lr=1e-3,
    relative_step=False,
    block_size=2048,
    min_8bit_size=4096,
)
# Training loop...
```
For a complete example, please refer to basic_usage.py.
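Because the grouping rule in `get_param_groups` only looks at names and shapes, it can be checked without a GPU. The sketch below applies the same filter to a hypothetical list of `(name, shape)` pairs and additionally assumes that `min_8bit_size` is an element-count threshold below which a tensor stays in FP32 even inside a `quantize=True` group; that reading of the parameter is an assumption, not documented behavior.

```python
import math


def plan_groups(named_shapes, min_8bit_size=4096):
    """Predict which tensors would be stored in 8-bit (hypothetical helper)."""
    plan = {}
    for name, shape in named_shapes:
        # Same protection rule as get_param_groups in the usage example.
        protected = (len(shape) <= 1 or "bias" in name
                     or "norm" in name or "embed" in name)
        # ASSUMPTION: small tensors stay in FP32 even inside a quantize=True group.
        big_enough = math.prod(shape) >= min_8bit_size
        plan[name] = "8bit" if (not protected and big_enough) else "fp32"
    return plan


toy_model = [
    ("embed_tokens.weight", (32000, 512)),
    ("layers.0.attn.q_proj.weight", (512, 512)),
    ("layers.0.norm.weight", (512,)),
    ("head.weight", (16, 16)),
    ("head.bias", (16,)),
]
print(plan_groups(toy_model))
```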
Acknowledgements
Thanks to the large language models Qwen and DeepSeek for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
Thanks to Tim Dettmers for the inspiration from the paper "8-bit Optimizers via Block-wise Quantization" and the bitsandbytes library.
Thanks to the PyTorch team for providing the foundational Optimizer implementation and the C++ Extension toolchain.
License
Download files
Source Distribution
Built Distribution
File details
Details for the file adafactor8bit-0.1.1.tar.gz.
File metadata
- Download URL: adafactor8bit-0.1.1.tar.gz
- Upload date:
- Size: 14.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7c2cbd95b818e18744e77570d3eb5c988aa0b3b53049a5c64e9f9e190ad459b9 |
| MD5 | 2d09db331e90b4c348b2fecbd2b49191 |
| BLAKE2b-256 | 7ea97f5739383dbc93c44da1ef1fc218da5a97808c01f3f048db8531a9092f89 |
File details
Details for the file adafactor8bit-0.1.1-py3-none-any.whl.
File metadata
- Download URL: adafactor8bit-0.1.1-py3-none-any.whl
- Upload date:
- Size: 12.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d3b37e0c8e2f2bd974e6aa2841550508b96e49e79b034732377ac03848835f64 |
| MD5 | 7d674b0918ddd509b7db8a568063e17e |
| BLAKE2b-256 | 6891364474ae0b6d12cc5b95db1bbdefaac5f47743562b26d5b73b3fd7979893 |