Statistical adaptive stochastic optimization methods

Statistical Adaptive Stochastic Gradient Methods

A package of PyTorch optimizers that can automatically schedule learning rates based on online statistical tests.

  • main algorithms: SALSA and SASA
  • auxiliary code: QHM and SSLS

Companion paper: Statistical Adaptive Stochastic Gradient Methods by Zhang, Lang, Liu and Xiao, 2020.

Install

pip install statopt

Or install from GitHub (over https, since GitHub no longer serves the unauthenticated git:// protocol):

pip install git+https://github.com/microsoft/statopt.git#egg=statopt

Usage of SALSA and SASA

Here we outline the key steps on CIFAR10. Complete Python code is given in examples/cifar_example.py.

Common setups

First, choose a batch size and prepare the dataset and data loader as in this PyTorch tutorial:

import torch, torchvision

batch_size = 128
trainset = torchvision.datasets.CIFAR10(root='../data', train=True, ...)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, ...)

Choose device, network model, and loss function:

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = torchvision.models.resnet18().to(device)
loss_func = torch.nn.CrossEntropyLoss()

SALSA

Import statopt, and initialize SALSA with a small learning rate and two extra parameters:

import math
import statopt

gamma = math.sqrt(batch_size/len(trainset))             # smoothing parameter for line search
testfreq = min(1000, len(trainloader))                  # frequency to perform statistical test 

optimizer = statopt.SALSA(net.parameters(), lr=1e-3,            # any small initial learning rate 
                          momentum=0.9, weight_decay=5e-4,      # common choices for CIFAR10/100
                          gamma=gamma, testfreq=testfreq)       # two extra parameters for SALSA
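To make the two extra parameters concrete, the following sketch evaluates the same formulas without loading any data. It assumes the standard CIFAR10 training set of 50,000 images and a data loader that keeps the last partial batch, so that len(trainloader) equals the number of samples divided by the batch size, rounded up. The helper name salsa_extra_params is hypothetical and not part of statopt.

```python
import math


def salsa_extra_params(num_samples, batch_size, max_testfreq=1000):
    """Hypothetical helper mirroring the gamma and testfreq formulas above."""
    # Number of mini-batches per epoch, i.e. len(trainloader) when the
    # loader keeps the last partial batch (assumption: drop_last=False).
    batches_per_epoch = math.ceil(num_samples / batch_size)
    # Smoothing parameter for line search.
    gamma = math.sqrt(batch_size / num_samples)
    # Run the statistical test at most once per epoch, capped at max_testfreq steps.
    testfreq = min(max_testfreq, batches_per_epoch)
    return gamma, testfreq


# CIFAR10 training set: 50,000 images, batch size 128 as in the setup above.
gamma, testfreq = salsa_extra_params(num_samples=50_000, batch_size=128)
print(f"gamma = {gamma:.4f}, testfreq = {testfreq}")  # gamma = 0.0506, testfreq = 391
```

With these settings an epoch is shorter than 1000 steps, so the statistical test runs once per epoch (every 391 steps).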

Training code using SALSA

for epoch in range(100):
    for (images, labels) in trainloader:
        net.train()  # always switch back to train() mode; the closure below may call net.eval()

        # Compute model outputs and loss function 
        images, labels = images.to(device), labels.to(device)
        loss = loss_func(net(images), labels)

        # Compute gradient with back-propagation 
        optimizer.zero_grad()
        loss.backward()

        # SALSA requires a closure function for line search
        def eval_loss(eval_mode=True):
            if eval_mode:
                net.eval()
            with torch.no_grad():
                loss = loss_func(net(images), labels)
            return loss

        optimizer.step(closure=eval_loss)

SASA

SASA requires a good (hand-tuned) initial learning rate, like most other optimizers, but does not use line search:

optimizer = statopt.SASA(net.parameters(), lr=1.0,              # need a good initial learning rate 
                         momentum=0.9, weight_decay=5e-4,       # common choices for CIFAR10/100
                         testfreq=testfreq)                     # frequency for statistical tests

Within the training loop: optimizer.step() does NOT need any closure function.
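A closure-free loop of this kind can be sketched as below. This is an illustration under stated assumptions, not the code in examples/cifar_example.py: the helper train_one_epoch is hypothetical, the data is a small synthetic tensor dataset, and torch.optim.SGD is used only as a stand-in so that the snippet runs without statopt installed. In real use, the statopt.SASA optimizer constructed above would be passed in its place.

```python
import torch
import torch.utils.data


def train_one_epoch(net, loader, optimizer, loss_func, device="cpu"):
    """Hypothetical helper: one pass over loader, returns the mean mini-batch loss."""
    net.train()
    total_loss, num_batches = 0.0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Compute the loss and its gradient with back-propagation.
        optimizer.zero_grad()
        loss = loss_func(net(inputs), labels)
        loss.backward()

        # No closure is passed here, unlike the SALSA loop.
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)


# Synthetic stand-in data: 64 samples, 10 features, 3 classes.
torch.manual_seed(0)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 10),
                                         torch.randint(0, 3, (64,)))
loader = torch.utils.data.DataLoader(dataset, batch_size=16)

net = torch.nn.Linear(10, 3)
loss_func = torch.nn.CrossEntropyLoss()
# Stand-in optimizer (assumption); replace with the statopt.SASA instance in practice.
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

mean_loss = train_one_epoch(net, loader, optimizer, loss_func)
print(f"mean training loss: {mean_loss:.4f}")
```

Keeping the loop in a function that only calls optimizer.step() makes it easy to swap optimizers for comparison, as long as none of them requires a closure.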
