SGD-Boost Optimizer Implementation, designed specifically for PyTorch.
Project description
SGD_Boost
This repository contains the official PyTorch implementation of the paper "Why Transformers Don’t Need Adam: Scale Is All You Need" (available on arXiv), providing a runtime- and memory-efficient optimizer, named "SGD-Boost", for training large language models and diffusion models.
Adaptive gradient methods like Adam and AdamW are popular for training transformer-based models due to their strong performance, but they have significant memory overhead and require extensive hyperparameter tuning. In this work, we argue that these adaptive methods are not necessary for effective training.
SGD-Boost introduces a learning rate rescaling strategy based on initial gradient patterns, applied to stochastic gradient descent with momentum (SGDM). This simple modification allows SGDM to achieve performance levels comparable to AdamW while reducing memory consumption and execution time. By removing the need to store second-order momentum terms, our approach reduces optimizer state memory by half, providing a “free lunch” in training efficiency.
Our method also enhances robustness to variations in learning rate and weight decay during ViT training on ImageNet-1K. Experimental results show that it outperforms existing optimizers in LoRA training for both large language models (LLMs) and diffusion models (DMs). In particular, it enables full-precision (FP32) training of GPT-2 (1.5B parameters) on a single RTX 3090 and of Llama2-7B on a single A100-80G GPU. Code is now available at GitHub.
Figure 1: We analyze four key parameters: the weights of the Query, Key, and Value (QKV) in the first Transformer block; the normalization layer; the fully connected layers within that block; and the final MLP head layer. The gradient signal-to-noise ratio (g-SNR) differs across parameter groups but remains stable throughout the training process. We utilize this signal to create a scaling strategy that adjusts the fixed learning rates in Stochastic Gradient Descent (SGD).
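As a rough illustration of how such a per-group signal could be inspected in PyTorch, the sketch below estimates a gradient signal-to-noise ratio for a few named parameter groups after a backward pass. The ratio used here (mean absolute gradient over gradient standard deviation) is an assumption made for illustration; the paper's exact g-SNR definition may differ, and all names are illustrative.

import torch

def estimate_g_snr(named_groups, eps=1e-8):
    """Illustrative per-group g-SNR estimate: mean |g| / std(g).
    This is an assumed formula, not necessarily the paper's exact definition."""
    snr = {}
    for name, params in named_groups.items():
        grads = [p.grad.detach().flatten() for p in params if p.grad is not None]
        if grads:
            g = torch.cat(grads)
            snr[name] = (g.abs().mean() / (g.std() + eps)).item()
    return snr

# usage after a loss.backward(), with hypothetical parameter-group names:
# print(estimate_g_snr({'qkv': qkv_params, 'norm': norm_params, 'fc': fc_params, 'head': head_params}))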
Figure 2:
Left: The figure shows the significant memory overhead of optimizer states as model size increases; SGD-Boost maintains much lower memory usage than other optimizers.
Right: This figure shows the results of a grid search on the classic ResNet18 model with the CIFAR10 dataset; the maximum top-1 test accuracy is highlighted in red text. Our method surpasses other popular optimizers, achieving the highest test accuracy.
Figure 3: The pseudocode of the SGD-Boost optimizer.
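As a hedged sketch of the structure that pseudocode describes, SGD-Boost can be read as SGD with momentum whose per-group learning rates are rescaled once, using a g-SNR-style statistic measured from the first gradients. The class below is a simplified illustration under that reading, not the packaged implementation; the g-SNR formula and class name are assumptions.

import torch

class SGDBoostSketch(torch.optim.Optimizer):
    """Simplified illustration: SGDM plus a one-time, g-SNR-based learning-rate
    rescale per parameter group. Not the packaged SGD-Boost implementation."""

    def __init__(self, params, lr=1e-3, momentum=0.9, eps=1e-8, weight_decay=0.0):
        defaults = dict(lr=lr, momentum=momentum, eps=eps,
                        weight_decay=weight_decay, scale=1.0)
        super().__init__(params, defaults)

    @torch.no_grad()
    def warmup_step(self):
        # Measure a g-SNR-style statistic once from the first gradients
        # and freeze it as a fixed per-group learning-rate scale.
        for group in self.param_groups:
            grads = [p.grad.flatten() for p in group['params'] if p.grad is not None]
            if grads:
                g = torch.cat(grads)
                group['scale'] = (g.abs().mean() / (g.std() + group['eps'])).item()

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            lr = group['lr'] * group['scale']
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad
                if group['weight_decay'] != 0:
                    d_p = d_p.add(p, alpha=group['weight_decay'])
                # standard heavy-ball momentum update with the rescaled learning rate
                buf = self.state[p].setdefault('momentum_buffer', torch.zeros_like(p))
                buf.mul_(group['momentum']).add_(d_p)
                p.add_(buf, alpha=-lr)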
How To Use
Installation
Prerequisites:
- Python >= 3.6
- PyTorch >= 1.7.0
Since the optimizer mostly uses native PyTorch APIs, it should be broadly compatible across PyTorch versions. However, we recommend PyTorch 2.x for better performance and compatibility.
Install from PyPI:
pip install sgd-boost
Install from the source code:
git clone https://github.com/AnonymousAlethiometer/SGD_Boost.git
cd SGD_Boost
# you can add the flag "--use-feature=in-tree-build" to avoid the "FutureWarning: The 'build' command is deprecated" warning
pip install . --use-feature=in-tree-build
# [Optional] Or you can use the '-e' flag to install in editable mode
pip install -e . --use-feature=in-tree-build
Usage of the optimizer:
The optimizer is normally used in the following way:
from sgd_boost import SGDBoost

# initialize the optimizer
optimizer = SGDBoost(model.parameters(), lr=lr, momentum=0.9, eps=1e-08, weight_decay=weight_decay)

for _ in range(steps):
    pred = model(input_ids)
    loss = loss_fn(pred, labels)
    # calculate the gradient
    loss.backward()
    # perform the warmup step; it is only needed once, after the first gradient has been calculated
    if not hasattr(optimizer, 'has_warmup') and hasattr(optimizer, 'warmup_step'):
        optimizer.warmup_step()
        optimizer.has_warmup = True
    # update the parameters
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
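For a quick end-to-end check that the package is wired up correctly, here is a minimal, self-contained sketch; the toy model, data, and hyperparameter values are illustrative, not settings from the paper:

import torch
import torch.nn as nn
from sgd_boost import SGDBoost

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
loss_fn = nn.CrossEntropyLoss()
optimizer = SGDBoost(model.parameters(), lr=0.1, momentum=0.9, eps=1e-08, weight_decay=5e-4)

inputs = torch.randn(16, 32)
labels = torch.randint(0, 10, (16,))
for step in range(10):
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    if step == 0 and hasattr(optimizer, 'warmup_step'):
        optimizer.warmup_step()  # one-time g-SNR measurement after the first backward pass
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    print(f'step {step}: loss {loss.item():.4f}')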
Distributed Training
For distributed training, you need to ensure that the g-SNR calculation (referred to as the .warmup_step()) is performed on each worker. Even if you accidentally perform it multiple times, it will not affect the final result, thanks to the stability of the g-SNR values. Feel free to use the optimizer in your own training scripts.
In most circumstances, you only need to replace the original optimizer with ours, perform the .warmup_step() after the first gradient calculation (i.e. the first effective invocation of loss.backward()), and keep the rest of the code unchanged, as sketched below.
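Here is a minimal sketch of that pattern with DistributedDataParallel, assuming the script is launched via torchrun; the toy model and data are illustrative:

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from sgd_boost import SGDBoost

# e.g. launched with: torchrun --nproc_per_node=2 train_ddp.py
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

model = DDP(nn.Linear(32, 10).cuda(), device_ids=[local_rank])
loss_fn = nn.CrossEntropyLoss()
optimizer = SGDBoost(model.parameters(), lr=0.1, momentum=0.9, eps=1e-08, weight_decay=5e-4)

for step in range(10):
    inputs = torch.randn(16, 32, device=local_rank)
    labels = torch.randint(0, 10, (16,), device=local_rank)
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    if step == 0:
        optimizer.warmup_step()  # run on every worker; extra calls are harmless since g-SNR is stable
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

dist.destroy_process_group()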
Example:
The CNN examples live in the examples directory. It contains the training code for CNN models, as well as the profiling code for evaluating the optimizer's performance.
The README in that directory will guide you through restoring the environment. Due to the anonymization procedure, although the main part has been left unchanged, some components may not be available; delete the unavailable resources as needed.
The ViT example will be released soon.
Acknowledgement
- The codebase is based on the timm (pytorch-image-models; ViT training example, release soon), NanoGPT, and Adam-mini (GPT-2 training) repositories.
- We thank the PyTorch Profiler for providing an accurate and efficient way to profile the memory usage of the optimizer.
Citation
If you find this work helpful, please consider citing our paper:
@article{XXXXXXXXXX,
title={Why Transformers Don’t Need Adam: Scale Is All You Need},
author={Anonymous},
journal={arXiv preprint arXiv:24XX.XXXXX},
year={2024}
}
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distributions
File details
Details for the file sgd_boost-1.0.2.tar.gz.
File metadata
- Download URL: sgd_boost-1.0.2.tar.gz
- Upload date:
- Size: 11.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.10.0 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/1.0.0 urllib3/1.26.20 tqdm/4.64.1 importlib-metadata/4.8.3 keyring/23.4.1 rfc3986/1.5.0 colorama/0.4.5 CPython/3.6.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | c75b6e67f2a6aa4f28149dfdb33f2f2290d0d396140986a7b0be36a68ca2b87b
MD5 | 6d75711ad7288aecec271dc215008290
BLAKE2b-256 | 1246a8b3704643c692a90b9a6562902f44ea191009a49eec903f3454c782b685
File details
Details for the file sgd_boost-1.0.2-py3-none-any.whl.
File metadata
- Download URL: sgd_boost-1.0.2-py3-none-any.whl
- Upload date:
- Size: 10.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.10.0 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/1.0.0 urllib3/1.26.20 tqdm/4.64.1 importlib-metadata/4.8.3 keyring/23.4.1 rfc3986/1.5.0 colorama/0.4.5 CPython/3.6.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 762a70a1d3b41dc0373abc6bd01a595950f246202b834253d690c37596eeb1ca
MD5 | ad697cf199ac477a53cc37cc8c7d234b
BLAKE2b-256 | 448ec600cb9b9c7418119100a2248aab9eebecbeb45bae82276691000a8d7386
File details
Details for the file sgd_boost-1.0.2-py2.py3-none-any.whl.
File metadata
- Download URL: sgd_boost-1.0.2-py2.py3-none-any.whl
- Upload date:
- Size: 10.6 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.10.0 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/1.0.0 urllib3/1.26.20 tqdm/4.64.1 importlib-metadata/4.8.3 keyring/23.4.1 rfc3986/1.5.0 colorama/0.4.5 CPython/3.6.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 85ef2515a4c271ce619923def334623b4c617e56de970ee0140ef092d4e68d84
MD5 | 0cdfb99b80ee7dbaff3012f36337dc72
BLAKE2b-256 | 564ef3c426968366251644937501ef733c35f5a575de9f555de9594c8ee2ad51