
PyTorch implementation of the SWATS algorithm.

Project description

Switching from Adam to SGD

Wilson et al. (2018) show that "the solutions found by adaptive methods generalize worse (often significantly worse) than SGD, even when these solutions have better training performance. These results suggest that practitioners should reconsider the use of adaptive methods to train neural networks."

"SWATS, from Keskar & Socher (2017), a high-scoring paper at ICLR 2018, is a method proposed to automatically switch from Adam to SGD for better generalization performance. The idea of the algorithm itself is very simple. It uses Adam, which works well with minimal tuning, but once training reaches a certain stage, SGD takes over."
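The quoted description leaves "a certain stage" unspecified. As we understand the criterion in Keskar & Socher (2017), each Adam step p is compared with the gradient g: the scalar SGD learning rate γ is chosen so that the orthogonal projection of the SGD step -γg onto p equals p, which gives γ = (pᵀp) / (-pᵀg). An exponential moving average λ of γ is kept with decay β2, and training switches to SGD once the bias-corrected average λ / (1 - β2^k) is within a tolerance ε of the current γ, with that average becoming the SGD learning rate. The pure-Python sketch below illustrates only this switching criterion on plain lists. The names `projected_sgd_lr` and `SwitchMonitor` and the default values of `beta2` and `eps` are illustrative assumptions, and this is not the internal code of the pytorch-swats package.

```python
def projected_sgd_lr(step, grad):
    """Return gamma such that projecting -gamma * grad onto `step` gives `step`.

    Returns None when the step is orthogonal to the gradient (no estimate).
    """
    step_dot_grad = sum(p * g for p, g in zip(step, grad))
    if step_dot_grad == 0:
        return None
    step_sq = sum(p * p for p in step)
    return step_sq / (-step_dot_grad)


class SwitchMonitor:
    """Hypothetical helper tracking the Adam-to-SGD switching condition."""

    def __init__(self, beta2=0.999, eps=1e-9):
        self.beta2 = beta2      # decay of the moving average of gamma (assumed value)
        self.eps = eps          # switching tolerance (assumed value)
        self.k = 0              # number of Adam steps seen so far
        self.lam = 0.0          # uncorrected moving average of gamma
        self.switched = False
        self.sgd_lr = None      # learning rate for the SGD phase, set on switch

    def update(self, step, grad):
        """Feed one Adam step and its gradient; return True once switched."""
        if self.switched:
            return True
        self.k += 1
        gamma = projected_sgd_lr(step, grad)
        if gamma is None:
            # step orthogonal to the gradient: skip this iteration's estimate
            return False
        self.lam = self.beta2 * self.lam + (1.0 - self.beta2) * gamma
        corrected = self.lam / (1.0 - self.beta2 ** self.k)  # bias correction
        if self.k > 1 and abs(corrected - gamma) < self.eps:
            self.switched = True
            self.sgd_lr = corrected
        return self.switched


if __name__ == "__main__":
    monitor = SwitchMonitor()
    grad = [1.0, 1.0]
    adam_step = [-0.1, -0.1]  # hypothetical Adam step for this gradient
    for _ in range(3):
        print(monitor.update(adam_step, grad), monitor.sgd_lr)
```

With a constant γ, as in the demo, the bias-corrected average equals γ up to rounding, so the monitor reports the switch on the second step. With real training noise the estimate only settles after many steps, which is what delays the switch in practice.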

Usage

The package can be installed with pip, either directly from this git repository or from PyPI, using one of the following commands.

pip install git+https://github.com/Mrpatekful/swats
pip install pytorch-swats

After installation, SWATS can be used like any other torch.optim.Optimizer. The following code snippet gives a simple overview of how to use the algorithm.

import torch
import swats

# `model` and `loss_fn` are assumed to be defined elsewhere
optimizer = swats.SWATS(model.parameters())
data_loader = torch.utils.data.DataLoader(...)

for epoch in range(10):
    for inputs, targets in data_loader:
        # clearing the gradients stored by the previous step
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()

        # performing the parameter update
        optimizer.step()


Download files


Source Distribution

pytorch-swats-0.1.0.tar.gz (3.9 kB)

Built Distribution

pytorch_swats-0.1.0-py3-none-any.whl (5.3 kB)

File details

Details for the file pytorch-swats-0.1.0.tar.gz.

File metadata

  • Download URL: pytorch-swats-0.1.0.tar.gz
  • Size: 3.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.6.7

File hashes

Hashes for pytorch-swats-0.1.0.tar.gz
Algorithm Hash digest
SHA256 63a8c0b61f78b96aa57624878f954372d21eea8b0bcea77f1efd2024b1076787
MD5 d8d96136213ecbfd7b2f8bdb73539c02
BLAKE2b-256 683199090940509e24966dfa34545a93d9e078a5a440e01c1dc2cb06c55e9c88

File details

Details for the file pytorch_swats-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: pytorch_swats-0.1.0-py3-none-any.whl
  • Size: 5.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.6.7

File hashes

Hashes for pytorch_swats-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 11e84ac4eef22e92776f8fd151845cc33fb36ffab903fd2c9514375a1b4bb149
MD5 e48589788a343376fee977158e9d4d10
BLAKE2b-256 1264a759b912da86176447520a39ca06825c599da75716276d92846ea9a6ac99
