Skip to main content

SGD-SaI Optimizer Implementation, designed for pytorch specificly.

Project description

SGD-SaI

This repository contains the official PyTorch implementation of the paper "No More Adam: Learning Rate Scaling at Initialization is All You Need" at arxiv, providing a runtime and memory efficient optimizer for training large language models and diffusion models, named "SGD-SaI".

In this work, we question the necessity of adaptive gradient methods for training deep neural networks. To address this, we introduce SGD-SaI, a simple yet effective enhancement to stochastic gradient descent with momentum (SGDM). SGD-SaI assigns distinct preconditioned learning rate scales to various parameter groups based on their initial gradient signal-to-noise ratios (g-SNR). Unlike Adam-like methods that depend on second-order momentum to adjust learning rates in response to imbalances detected in the gradient norm history, g-SNR enables us to identify gradient variance prior to the start of training. This proactive adjustment of learning rate scales minimizes the variance of parameter updates throughout the training process.


Figure1:This graph illustrates the differences in local gain behaviours exhibited by four optimizers throughout the training process. We present two popular adaptive gradient methods: Adam(W) and the memory-efficient Adam-mini. The local gains for these methods are recalculated continuously at each step based on the gradients. In contrast, SGD and SGD-SaI are both non-adaptive methods, meaning their local gains remain fixed throughout the training.

Results


Figure2:This figure displays the results from a grid search conducted on the classic ResNet18 model using the CIFAR10 dataset and ViT-S/16 model using the ImageNet-1K dataset. The maximum top-1 test accuracy is highlighted in red text. Our method surpasses other popular optimizers, achieving the highest test accuracy in traditional CNNs. Additionally, our method demonstrates superior performance compared to other popular optimizers, particularly in terms of stability in response to changes in hyperparameters.


Figure3: This figures shows the pre-train results on GPT2 and ViT. Our method save 50% memory usage of the state tensors of AdamW, and it is 3x faster than AdamMini. And for ViT, although our method has a slower convergence speed, we can still achieve comparable performance by the end of the training process. Additionally, our approach is designed to have a lower memory footprint and a faster optimization speed.

Efficiency


Figure4: The chart illustrates how memory usage and optimizer step time (in wall-clock time) increase with larger model sizes. It highlights the substantial memory overhead of storing optimizer states as model sizes grow. SGD-SaI exhibits significantly lower memory usage than AdamW and has the shortest optimization step runtime. This runtime refers to the wall clock time required for the optimizer step function. All statistics were measured on a single NVIDIA A100-80GB.


Figure3: Left: The figure presents the results of a grid search conducted on the ViT-S/16 model using the ImageNet-1K dataset. For better visualization and comparison, the axis for the learning rate has been inverted for SGD-like optimizers. The maximum top-1 test accuracy is highlighted in red text. Additionally, our method demonstrates superior performance compared to other popular optimizers, particularly in terms of stability in response to changes in hyperparameters. Right:The pseudocode of the SGD-SaI optimizer.

How To Use

Installation

Prerequisites:

  • Python >= 3.6
  • PyTorch >= 1.7.0

Since most of this optimizer uses the native PyTorch APIs, it should have a wider compatibility with different versions of PyTorch. However, we recommend using the Pytorch 2.X version for better performance and compatibility.

Install from PyPI:

pip install SGD-SaI

Install from the source code:

git clone https://github.com/AnonymousAlethiometer/SGD_SaI.git

cd SGD_SaI

# you can use the flag "--use-feature=in-tree-build" to avoid the warning of "FutureWarning: The 'build' command is deprecated"
pip install . --use-feature=in-tree-build

# [Optional] Or you can use '-e' flag to install in editable mode
pip install -e . --use-feature=in-tree-build

Usgae of the optimizer:

The optimizer is normally used in the following way:

from sgd_sai import SGD_sai

# initialize the optimizer
optimizer = SGD_sai(model.parameters(), lr=lr, momentum=0.9, eps=1e-08, weight_decay=weight_decay)

for _ in range(steps):
    pred = model(input_ids)
    loss = loss_fn(pred, labels)
    # calculate the gradient
    loss.backward()
    # From version 1.0.3, we do not need to call the warmup_step() manually, it will be called automatically after the first gradient calculation. Rightnow, it is user-friendly for different dirstributed training frameworks.
    # # process the warmup step, only need once after the gradient is calculated
    # # if not hasattr(optimizer, 'has_warmup') and hasattr(optimizer, 'warmup_step'):
    # #     optimizer.warmup_step()
    # #     optimizer.has_warmup = True
    # # update the parameters
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

Distributed Training

Version before: 1.0.3.

For distributed training, you need to ensure to perform this g-SNR calculation (refer as the .warmup step()) on each worker. Even if you accidentally perform it multiple times, it will not affect the final result thanks to the stability of the g-SNR values. Feel free to use the optimizer in your own training scripts.

In most circumstances, you only need to replace the original optimizer with our optimizer, perform the .warmup step() after first gradient calculation (aka. the first effective invoke of loss.backwards()) and keep the rest of the code unchanged.

Version after: 1.0.3.

Nothing need to be changed, the optimizer will automatically perform the warmup step after the first gradient calculation. Juts use it like a normal Pytorch optimizer.

Example:

The CNN examples lie in the examples directory. It contains the training code for CNN models, as well as the profiling code for the optimizer perfomance evaluation.

Please follow the README in that directory will guide you to restore the environment. Due to the procedure of anonymization, although the main part has been leaved unchanged, some of the components may not be available, try to delete the unavailable resources as needed.

The ViT example will be released soon.

Acknowledgement

  1. The codebase is based on the timm:pytorch-image-models(ViT training example, release soon), NanoGPT and Adam-mini(GPT2 training) repository.

  2. We thanks for Pytorch Profiler for an accurate and efficient way to profile the memory usage of the optimizer.

Citation

If you find this work helpful, please consider citing our paper:

@article{XXXXXXXXXX,
  title={No More Adam: Learning Rate Scaling at Initialization is All You Need},
  author={Anonymous},
  journal={arXiv preprint arXiv:24XX.XXXXX},
  year={2024}
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

sgd_sai-1.0.3.tar.gz (14.3 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

sgd_sai-1.0.3-py2.py3-none-any.whl (11.0 kB view details)

Uploaded Python 2Python 3

File details

Details for the file sgd_sai-1.0.3.tar.gz.

File metadata

  • Download URL: sgd_sai-1.0.3.tar.gz
  • Upload date:
  • Size: 14.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.16

File hashes

Hashes for sgd_sai-1.0.3.tar.gz
Algorithm Hash digest
SHA256 7418e5a90df2725e8b65d05d0edfa2b00f292f5ae27bab00f53053a21380c1d8
MD5 6d9a39662a11251dbf83ab0bea24a2ec
BLAKE2b-256 783781cd8a1cf8926a60926236d13977731e97a4e344c8844956cf00169e2d45

See more details on using hashes here.

File details

Details for the file sgd_sai-1.0.3-py2.py3-none-any.whl.

File metadata

  • Download URL: sgd_sai-1.0.3-py2.py3-none-any.whl
  • Upload date:
  • Size: 11.0 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.16

File hashes

Hashes for sgd_sai-1.0.3-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 5f5e6b5d8ec5371ce4ed297bafd16b4aa9b41e45643558a4ad6371902a593437
MD5 e39445dfe30e1dbef1e3054ecb107381
BLAKE2b-256 2587969191bfba9506c22a34467054898d07b92a6a75ffcce2bf84c1d380f6c7

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page