
SGD-Boost Optimizer Implementation, designed specifically for PyTorch.

Project description

SGD_Boost

This repository contains the official PyTorch implementation of the paper "Why Transformers Don't Need Adam: Scale Is All You Need" (available on arXiv). It provides "SGD-Boost", a runtime- and memory-efficient optimizer for training large language models and diffusion models.

Adaptive gradient methods like Adam and AdamW are popular for training transformer-based models due to their strong performance, but they have significant memory overhead and require extensive hyperparameter tuning. In this work, we argue that these adaptive methods are not necessary for effective training.

SGD-Boost introduces a learning rate rescaling strategy based on initial gradient patterns, applied to stochastic gradient descent with momentum (SGDM). This simple modification allows SGDM to achieve performance levels comparable to AdamW while reducing memory consumption and execution time. By removing the need to store second-order momentum terms, our approach reduces optimizer state memory by half, providing a “free lunch” in training efficiency.

Our method also improves robustness to variations in learning rate and weight decay during ViT training on ImageNet-1K. Experimental results show that it outperforms existing optimizers in LoRA training for both large language models (LLMs) and diffusion models (DMs). In particular, it enables full-precision (FP32) training of GPT-2 (1.5B parameters) on a single RTX 3090 and of Llama2-7B on a single A100-80G GPU. Code is now available on GitHub.


Figure 1: We analyze four key parameter groups: the Query, Key, and Value (QKV) weights in the first Transformer block; the normalization layer; the fully connected layers within that block; and the final MLP head layer. The gradient signal-to-noise ratio (g-SNR) differs across parameter groups but remains stable throughout training. We use this signal to build a scaling strategy that adjusts the fixed learning rates in Stochastic Gradient Descent (SGD).
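
To make the idea concrete, here is a minimal sketch of how such a per-parameter g-SNR-like statistic could be inspected. It assumes a simple definition (magnitude of the mean gradient entry over the standard deviation of the gradient entries); the exact definition used by SGD-Boost is the one given in the paper and in Figure 3.

import torch

@torch.no_grad()
def gradient_snr(param: torch.nn.Parameter, eps: float = 1e-8) -> float:
    # illustrative g-SNR-like statistic: |mean(grad)| / (std(grad) + eps)
    g = param.grad
    if g is None or g.numel() < 2:
        return 0.0
    return (g.mean().abs() / (g.std() + eps)).item()

# after one backward pass:
# for name, p in model.named_parameters():
#     print(name, gradient_snr(p))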


Figure 2: Left: Optimizer-state memory overhead grows substantially with model size; SGD-Boost maintains much lower memory usage than other optimizers. Right: Results of a grid search on the classic ResNet18 model using the CIFAR-10 dataset, with the maximum top-1 test accuracy highlighted in red. Our method surpasses other popular optimizers, achieving the highest test accuracy.
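
The memory saving comes from dropping the second-moment buffer that Adam-style optimizers keep per parameter. A rough back-of-the-envelope estimate (not from the paper), assuming FP32 optimizer states, two buffers per parameter for AdamW and one momentum buffer for an SGDM-style optimizer such as SGD-Boost:

def optimizer_state_gib(num_params: int, buffers_per_param: int, bytes_per_value: int = 4) -> float:
    # optimizer-state size only; weights, gradients and activations are extra
    return num_params * buffers_per_param * bytes_per_value / 1024**3

n = 1_500_000_000  # e.g. GPT-2 1.5B
print(f"AdamW-style states (2 buffers): {optimizer_state_gib(n, 2):.1f} GiB")  # ~11.2 GiB
print(f"SGDM-style states  (1 buffer):  {optimizer_state_gib(n, 1):.1f} GiB")  # ~5.6 GiB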


Figure 3: The pseudocode of the SGD-Boost optimizer.

How To Use

Installation

Prerequisites:

  • Python >= 3.6
  • PyTorch >= 1.7.0

Since most of this optimizer relies on native PyTorch APIs, it should be compatible with a wide range of PyTorch versions. However, we recommend PyTorch 2.x for better performance and compatibility.

Install from PyPI:

pip install sgd-boost

Install from the source code:

git clone https://github.com/AnonymousAlethiometer/SGD_Boost.git

cd SGD_Boost

# the "--use-feature=in-tree-build" flag avoids the "FutureWarning: The 'build' command is deprecated" warning
pip install . --use-feature=in-tree-build

# [Optional] Alternatively, use the '-e' flag to install in editable mode
pip install -e . --use-feature=in-tree-build
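
To verify the installation, a quick import check (assuming the package exposes the SGDBoost class used in the usage section below):

python -c "from sgd_boost import SGDBoost; print(SGDBoost)"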

Usage of the optimizer:

The optimizer is normally used in the following way:

from sgd_boost import SGDBoost

# initialize the optimizer
optimizer = SGDBoost(model.parameters(), lr=lr, momentum=0.9, eps=1e-08, weight_decay=weight_decay)

for _ in range(steps):
    pred = model(input_ids)
    loss = loss_fn(pred, labels)
    # calculate the gradient
    loss.backward()
    # perform the warm-up step; it only needs to run once, after the first gradient computation
    if not hasattr(optimizer, 'has_warmup') and hasattr(optimizer, 'warmup_step'):
        optimizer.warmup_step()
        optimizer.has_warmup = True
    # update the parameters
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

Distributed Training

For distributed training, make sure to perform the g-SNR calculation (the .warmup_step() call) on each worker. Even if you accidentally perform it multiple times, this will not affect the final result, thanks to the stability of the g-SNR values. Feel free to use the optimizer in your own training scripts.

In most circumstances, you only need to replace the original optimizer with ours, call .warmup_step() after the first gradient computation (i.e., the first effective call to loss.backward()), and keep the rest of the code unchanged, as sketched below.
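
A minimal sketch of this pattern under DistributedDataParallel (DDP); model, loader, loss_fn, local_rank, lr and weight_decay are assumed to be defined, and the process group is assumed to be initialized elsewhere:

from torch.nn.parallel import DistributedDataParallel as DDP
from sgd_boost import SGDBoost

ddp_model = DDP(model, device_ids=[local_rank])
optimizer = SGDBoost(ddp_model.parameters(), lr=lr, momentum=0.9, eps=1e-08, weight_decay=weight_decay)

for step, (input_ids, labels) in enumerate(loader):
    loss = loss_fn(ddp_model(input_ids), labels)
    loss.backward()  # DDP has already averaged the gradients across ranks here
    if step == 0:
        # every rank runs the g-SNR warm-up from its own gradients;
        # running it more than once is harmless because the g-SNR values are stable
        optimizer.warmup_step()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)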

Example:

The CNN examples are in the examples directory. It contains the training code for CNN models, as well as profiling code for evaluating the optimizer's performance.

Please follow the README in that directory; it will guide you through restoring the environment. Due to the anonymization procedure, the main part has been left unchanged, but some components may not be available; delete the unavailable resources as needed.

The ViT example will be released soon.

Acknowledgement

  1. The codebase is based on the timm (pytorch-image-models) repository (ViT training example, to be released soon), as well as the NanoGPT and Adam-mini (GPT-2 training) repositories.

  2. We thank the PyTorch Profiler team for providing an accurate and efficient way to profile the memory usage of the optimizer.

Citation

If you find this work helpful, please consider citing our paper:

@article{XXXXXXXXXX,
  title={Why Transformers Don’t Need Adam: Scale Is All You Need},
  author={Anonymous},
  journal={arXiv preprint arXiv:24XX.XXXXX},
  year={2024}
}

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

sgd_boost-1.0.2.tar.gz (11.2 kB)

Uploaded Source

Built Distributions

sgd_boost-1.0.2-py3-none-any.whl (10.6 kB)

Uploaded Python 3

sgd_boost-1.0.2-py2.py3-none-any.whl (10.6 kB)

Uploaded Python 2 Python 3

File details

Details for the file sgd_boost-1.0.2.tar.gz.

File metadata

  • Download URL: sgd_boost-1.0.2.tar.gz
  • Upload date:
  • Size: 11.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.10.0 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/1.0.0 urllib3/1.26.20 tqdm/4.64.1 importlib-metadata/4.8.3 keyring/23.4.1 rfc3986/1.5.0 colorama/0.4.5 CPython/3.6.13

File hashes

Hashes for sgd_boost-1.0.2.tar.gz

  • SHA256: c75b6e67f2a6aa4f28149dfdb33f2f2290d0d396140986a7b0be36a68ca2b87b
  • MD5: 6d75711ad7288aecec271dc215008290
  • BLAKE2b-256: 1246a8b3704643c692a90b9a6562902f44ea191009a49eec903f3454c782b685


File details

Details for the file sgd_boost-1.0.2-py3-none-any.whl.

File metadata

  • Download URL: sgd_boost-1.0.2-py3-none-any.whl
  • Upload date:
  • Size: 10.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.10.0 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/1.0.0 urllib3/1.26.20 tqdm/4.64.1 importlib-metadata/4.8.3 keyring/23.4.1 rfc3986/1.5.0 colorama/0.4.5 CPython/3.6.13

File hashes

Hashes for sgd_boost-1.0.2-py3-none-any.whl

  • SHA256: 762a70a1d3b41dc0373abc6bd01a595950f246202b834253d690c37596eeb1ca
  • MD5: ad697cf199ac477a53cc37cc8c7d234b
  • BLAKE2b-256: 448ec600cb9b9c7418119100a2248aab9eebecbeb45bae82276691000a8d7386


File details

Details for the file sgd_boost-1.0.2-py2.py3-none-any.whl.

File metadata

  • Download URL: sgd_boost-1.0.2-py2.py3-none-any.whl
  • Upload date:
  • Size: 10.6 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.10.0 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/1.0.0 urllib3/1.26.20 tqdm/4.64.1 importlib-metadata/4.8.3 keyring/23.4.1 rfc3986/1.5.0 colorama/0.4.5 CPython/3.6.13

File hashes

Hashes for sgd_boost-1.0.2-py2.py3-none-any.whl

  • SHA256: 85ef2515a4c271ce619923def334623b4c617e56de970ee0140ef092d4e68d84
  • MD5: 0cdfb99b80ee7dbaff3012f36337dc72
  • BLAKE2b-256: 564ef3c426968366251644937501ef733c35f5a575de9f555de9594c8ee2ad51

