
PyTorch implementation of the SWATS algorithm.

Project description

Switching from Adam to SGD

Wilson et al. (2017) show that "the solutions found by adaptive methods generalize worse (often significantly worse) than SGD, even when these solutions have better training performance. These results suggest that practitioners should reconsider the use of adaptive methods to train neural networks."

SWATS, proposed by Keskar & Socher (2017) and a high-scoring submission to ICLR 2018, automatically switches from Adam to SGD to obtain better generalization performance. The idea behind the algorithm is simple: training starts with Adam, which works well with minimal tuning, and once training reaches a certain stage, SGD takes over. In the paper, that stage is reached when an estimate of the equivalent SGD learning rate, obtained by projecting Adam's steps onto the gradient, stops changing.

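The switching rule can be sketched in a few lines. The following is a minimal, dependency-free sketch based on our reading of Algorithm 1 in Keskar & Socher's paper, not the code shipped in pytorch-swats. After each Adam step p_k, it forms the estimate γ_k = p_kᵀp_k / (−p_kᵀg_k), keeps an exponential average λ_k = β₂λ_{k−1} + (1 − β₂)γ_k, and switches once |λ_k/(1 − β₂^k) − γ_k| < ε. From then on, SGD with momentum runs at the rate λ_k/(1 − β₂^k). The helper names (`adam_step`, `check_switch`, `swats_toy`), the use of plain Python lists instead of tensors, and the default tolerance ε = 1e-9 are assumptions made for illustration.

```python
import math

# Hypothetical helpers for illustration only; pytorch-swats internals may differ.


def adam_step(grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-9):
    """Bias-corrected Adam step on plain lists; returns the step p_k and updated moments."""
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, grad)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, grad)]
    p = [
        -lr * (mi / (1 - beta1 ** step)) / (math.sqrt(vi / (1 - beta2 ** step)) + eps)
        for mi, vi in zip(m, v)
    ]
    return p, m, v


def check_switch(p, grad, lam, step, beta2=0.999, tol=1e-9):
    """Project the Adam step onto the gradient to estimate an SGD learning rate.

    Returns (switch, lam, sgd_lr); sgd_lr is None when no estimate is possible.
    """
    p_dot_g = sum(pi * gi for pi, gi in zip(p, grad))
    if p_dot_g == 0:
        return False, lam, None
    gamma = sum(pi * pi for pi in p) / -p_dot_g  # per-step estimate gamma_k
    lam = beta2 * lam + (1 - beta2) * gamma       # exponential average lambda_k
    sgd_lr = lam / (1 - beta2 ** step)            # bias-corrected estimate
    switch = step > 1 and abs(sgd_lr - gamma) < tol
    return switch, lam, sgd_lr


def swats_toy(grad_fn, x, steps=50, lr=0.1, beta1=0.9, beta2=0.999):
    """Run Adam until the switch test fires, then SGD with momentum at the estimated rate."""
    m, v, buf = [0.0] * len(x), [0.0] * len(x), [0.0] * len(x)
    lam, phase, sgd_lr = 0.0, "adam", None
    for k in range(1, steps + 1):
        g = grad_fn(x)
        if phase == "adam":
            p, m, v = adam_step(g, m, v, k, lr, beta1, beta2)
            x = [xi + pi for xi, pi in zip(x, p)]
            switch, lam, est = check_switch(p, g, lam, k, beta2)
            if switch:
                phase, sgd_lr = "sgd", est
        else:
            # momentum buffer, then update scaled by (1 - beta1) * estimated rate
            buf = [beta1 * bi + gi for bi, gi in zip(buf, g)]
            x = [xi - (1 - beta1) * sgd_lr * bi for xi, bi in zip(x, buf)]
    return x, phase, sgd_lr
```

With a constant gradient, for example `swats_toy(lambda x: [2.0], [1.0])`, the estimate is stable from the first step, so the switch fires at step 2 and the returned rate is close to lr / |g| = 0.05. On real losses the estimate fluctuates, so the switch would be expected to happen much later in training.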
Usage

The package can be installed with pip, either directly from the GitHub repository or from PyPI, using one of the following commands:

```shell
pip install git+https://github.com/Mrpatekful/swats
pip install pytorch-swats
```

After installation, SWATS can be used like any other torch.optim.Optimizer. The following snippet gives a brief overview of how to use the algorithm.

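```python
import torch
import swats

# `model` and `loss_fn` are assumed to be defined elsewhere,
# e.g. a torch.nn.Module and a loss such as torch.nn.CrossEntropyLoss().
optimizer = swats.SWATS(model.parameters())
data_loader = torch.utils.data.DataLoader(...)

for epoch in range(10):
    for inputs, targets in data_loader:
        # clear the gradients accumulated by the previous step
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()

        # perform the parameter update
        optimizer.step()
```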


Version: 0.1.0

Download files


Files for pytorch-swats, version 0.1.0:

| Filename | File type | Python version | Size |
|---|---|---|---|
| pytorch_swats-0.1.0-py3-none-any.whl | Wheel | py3 | 5.3 kB |
| pytorch-swats-0.1.0.tar.gz | Source | n/a | 3.9 kB |
