Statistical Adaptive Stochastic Gradient Methods
A package of PyTorch optimizers that can automatically schedule learning rates based on online statistical tests.
- main algorithms: SALSA and SASA
- auxiliary codes: QHM and SSLS
Companion paper: Statistical Adaptive Stochastic Gradient Methods by Zhang, Lang, Liu and Xiao, 2020.
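To convey the general idea before diving into the API, here is a deliberately simplified sketch of learning-rate scheduling driven by an online statistical test: collect a statistic of the optimization trajectory, test whether it has stopped improving, and cut the learning rate when the test fires. Everything in this sketch (using raw loss values as the statistic, a two-sample z-test on two halves of a buffer, the window size and drop factor) is an illustrative assumption, not the statistic or test used by SASA/SALSA; see the companion paper for the actual method.

import math
from collections import deque

class NaiveStatTestScheduler:
    """Illustrative sketch only; not the statopt internals."""
    def __init__(self, optimizer, window=200, drop_factor=0.1, z_crit=1.96):
        self.optimizer = optimizer
        self.buffer = deque(maxlen=window)   # recent loss values
        self.drop_factor = drop_factor       # multiply lr by this on a drop
        self.z_crit = z_crit                 # ~95% two-sided critical value

    def step(self, loss_value):
        self.buffer.append(loss_value)
        if len(self.buffer) < self.buffer.maxlen:
            return
        half = len(self.buffer) // 2
        old, new = list(self.buffer)[:half], list(self.buffer)[half:]
        mean_o = sum(old) / half
        mean_n = sum(new) / half
        var_o = sum((x - mean_o) ** 2 for x in old) / (half - 1)
        var_n = sum((x - mean_n) ** 2 for x in new) / (half - 1)
        # z > z_crit means the loss is still decreasing significantly.
        z = (mean_o - mean_n) / math.sqrt(var_o / half + var_n / half + 1e-12)
        if z < self.z_crit:
            for group in self.optimizer.param_groups:
                group['lr'] *= self.drop_factor
            self.buffer.clear()              # restart collection after a drop

During training one would call scheduler.step(loss.item()) once per iteration. SASA and SALSA replace this crude loss-based heuristic with a principled stationarity test on a statistic of the SGD-with-momentum dynamics.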
Install
pip install statopt
Or from GitHub:
pip install git+git://github.com/microsoft/statopt.git#egg=statopt
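A quick smoke test of the installation (SALSA and SASA are the two optimizer classes used in the examples below):

import statopt
print(statopt.SALSA, statopt.SASA)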
Usage of SALSA and SASA
Here we outline the key steps on CIFAR10. Complete Python code is given in examples/cifar_example.py.
Common setups
First, choose a batch size and prepare the dataset and data loader as in this PyTorch tutorial:
import torch, torchvision
batch_size = 128
trainset = torchvision.datasets.CIFAR10(root='../data', train=True, ...)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, ...)
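The elided arguments above are the usual torchvision choices. A complete setup might look like the following sketch; the augmentation and per-channel normalization constants are common CIFAR10 recipe values, assumed here rather than mandated by statopt:

import torch, torchvision
import torchvision.transforms as transforms

batch_size = 128
# Standard CIFAR10 augmentation and normalization (assumed defaults).
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616)),
])
trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)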
Choose device, network model, and loss function:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = torchvision.models.resnet18().to(device)
loss_func = torch.nn.CrossEntropyLoss()
SALSA
Import math and statopt, then initialize SALSA with a small learning rate and two extra parameters:
import math, statopt

gamma = math.sqrt(batch_size/len(trainset))    # smoothing parameter for line search
testfreq = min(1000, len(trainloader))         # frequency to perform statistical test

optimizer = statopt.SALSA(net.parameters(), lr=1e-3,           # any small initial learning rate
                          momentum=0.9, weight_decay=5e-4,     # common choices for CIFAR10/100
                          gamma=gamma, testfreq=testfreq)      # two extra parameters for SALSA
Training code using SALSA
for epoch in range(100):
    for (images, labels) in trainloader:
        net.train()                 # always switch to train() mode

        # Compute model outputs and loss function
        images, labels = images.to(device), labels.to(device)
        loss = loss_func(net(images), labels)

        # Compute gradient with back-propagation
        optimizer.zero_grad()
        loss.backward()

        # SALSA requires a closure function for line search
        def eval_loss(eval_mode=True):
            if eval_mode:
                net.eval()
            with torch.no_grad():
                loss = loss_func(net(images), labels)
            return loss

        optimizer.step(closure=eval_loss)
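To monitor progress, you would typically measure test accuracy between epochs. A minimal sketch, assuming a testloader built like trainloader above with train=False; this helper is ours, not part of statopt:

def evaluate(net, testloader, device):
    net.eval()                      # disable dropout/batch-norm updates
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            predictions = net(images).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total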
SASA
SASA requires a good (hand-tuned) initial learning rate like most other optimizers, but does not use line search:
optimizer = statopt.SASA(net.parameters(), lr=1.0,             # need a good initial learning rate
                         momentum=0.9, weight_decay=5e-4,      # common choices for CIFAR10/100
                         testfreq=testfreq)                    # frequency for statistical tests
Within the training loop, optimizer.step() does NOT need any closure function; a full loop is sketched below.
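For completeness, here is what the SASA training loop looks like, reusing the setup above; it is simply the SALSA loop minus the closure:

for epoch in range(100):
    for (images, labels) in trainloader:
        net.train()                 # always switch to train() mode
        images, labels = images.to(device), labels.to(device)
        loss = loss_func(net(images), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()            # no closure needed for SASA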