
PyTorch implementation of the method described in our paper "Bolstering Stochastic Gradient Descent with Model Building" by S. Ilker Birbil, Ozgur Martin, Gonenc Onay, and Figen Oztoprak, 2021 (see https://arxiv.org/abs/2111.07058).

Project description

Stochastic Model Building (SMB)

This repository provides a fast and robust stochastic optimization algorithm for training deep learning models. The core idea of the algorithm is to build models from local stochastic gradient information. The details of the algorithm are given in our recent paper.


Abstract

The stochastic gradient descent method and its variants constitute the core optimization algorithms that achieve good convergence rates for solving machine learning problems. These rates are obtained especially when these algorithms are fine-tuned for the application at hand. Although this tuning process can require large computational costs, recent work has shown that these costs can be reduced by line search methods that iteratively adjust the stepsize. We propose an alternative approach to stochastic line search by using a new algorithm based on forward step model building. This model building step incorporates second-order information that allows adjusting not only the stepsize but also the search direction. Noting that deep learning model parameters come in groups (layers of tensors), our method builds its model and calculates a new step for each parameter group. This novel diagonalization approach makes the selected step lengths adaptive. We provide a convergence rate analysis, and experimentally show that the proposed algorithm achieves faster convergence and better generalization in most problems. Moreover, our experiments show that the proposed method is quite robust as it converges for a wide range of initial stepsizes.
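The per-parameter-group ("diagonalization") idea in the abstract can be illustrated with a heavily simplified sketch: each group (layer) gets its own step length instead of one global stepsize. The scaling rule below is a hypothetical stand-in for illustration only, not the actual model-building step from the paper.

```python
# Illustrative only: one adaptive step length per parameter group,
# rather than a single global stepsize. The inverse-norm scaling rule
# here is a stand-in, NOT the SMB model-building step from the paper.

def per_group_steps(grads, base_lr=0.1):
    """Compute a group-specific step for each parameter group (layer)."""
    steps = []
    for g in grads:  # one gradient vector per parameter group
        norm = sum(x * x for x in g) ** 0.5
        lr = base_lr / max(norm, 1e-8)   # group-specific step length
        steps.append([-lr * x for x in g])
    return steps

# Two "layers" with very different gradient scales end up with
# comparably sized steps, which is the adaptivity being described.
grads = [[3.0, 4.0], [0.3, 0.4]]   # gradient norms 5.0 and 0.5
steps = per_group_steps(grads)
```

Note how the second group, despite a gradient ten times smaller, receives a step of the same magnitude as the first; a single global stepsize could not do this.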

Keywords: model building; second-order information; stochastic gradient descent; convergence analysis

Installation

pip install git+https://github.com/sbirbil/SMB.git

Testing

Here is how you can use SMB:

import torch
import smb

# model and train_loader are assumed to be defined as usual.
# Set independent_batch=True to use the SMBi variant instead.
optimizer = smb.SMB(model.parameters(), independent_batch=False)

for epoch in range(100):

    # training steps
    model.train()

    for batch_index, (data, target) in enumerate(train_loader):

        # create the loss closure for the SMB algorithm
        def closure():
            optimizer.zero_grad()
            loss = torch.nn.CrossEntropyLoss()(model(data), target)
            return loss

        # forward pass (SMB evaluates the closure inside the step)
        loss = optimizer.step(closure=closure)

You can also check our tutorial for a complete example (or the Colab notebook, which requires no installation). Set the hyper-parameter independent_batch to True to use the SMBi optimizer. Our paper includes more information.
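The step(closure=...) call above follows PyTorch's standard closure protocol: the optimizer is free to re-evaluate the loss as many times as it needs within one step. The MockOptimizer below is hypothetical and only demonstrates that contract, so the closure you write must be safely re-runnable:

```python
# Sketch of the closure protocol behind optimizer.step(closure=...).
# MockOptimizer is hypothetical; it only shows that the closure may be
# called more than once per step (e.g. at a trial point), so it must
# recompute the loss from scratch each time it runs.

class MockOptimizer:
    def step(self, closure):
        loss = closure()        # first evaluation (as in plain SGD)
        trial_loss = closure()  # re-evaluation, e.g. at a trial point
        return min(loss, trial_loss)

calls = []
def closure():
    calls.append(1)  # in real code: zero_grad and the forward pass
    return 0.5       # in real code: the loss tensor

loss = MockOptimizer().step(closure)
```

This is why the closure recreates the loss on every call instead of returning a cached value.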

Reproducing The Experiments

See the following script to reproduce the results reported in our paper.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

smb-optimizer-0.1.1.tar.gz (9.0 kB)

Uploaded Source

Built Distribution

smb_optimizer-0.1.1-py3-none-any.whl (7.8 kB)

Uploaded Python 3

File details

Details for the file smb-optimizer-0.1.1.tar.gz.

File metadata

  • Download URL: smb-optimizer-0.1.1.tar.gz
  • Upload date:
  • Size: 9.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.6 CPython/3.7.10 Linux/5.4.0-90-generic

File hashes

Hashes for smb-optimizer-0.1.1.tar.gz

  • SHA256: c1460a332587546e13d44d94a00878eb857648a999cf6ba0b8cb551e255b1ba2
  • MD5: 2c00ae3e58adb0bb5bcbc39d91e49c97
  • BLAKE2b-256: 5ea1e0432cfbe7489592e61eee19685e227dee5f9a1161959d83af4baf78cb57

See more details on using hashes here.

File details

Details for the file smb_optimizer-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: smb_optimizer-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 7.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.6 CPython/3.7.10 Linux/5.4.0-90-generic

File hashes

Hashes for smb_optimizer-0.1.1-py3-none-any.whl

  • SHA256: f642b6f787c52f162ac94dded81a9cab23f90d8ffd8cb8ff521db79a625363d7
  • MD5: d5eaff2559bf91dcc225d734b0f1cbd2
  • BLAKE2b-256: ba3be5b6b282f68ae965788fda24fc86c6e851f7f79c4568e6ca768d6e9cb6eb

