AdagradW - Decoupled Weight Decay in Adagrad

adagradw - Decoupled Weight Decay for Adagrad

The Adagrad optimizer preconditions each gradient update with the accumulated sum of squared gradients. With a batch size of 1 and a learning rate of 0 (so that the parameters stay fixed while gradients accumulate), this accumulated sum is a diagonal approximation to the empirical Fisher information matrix. When weight decay is enabled, however, the original PyTorch implementation adds the gradient of the regularization term to the loss gradient before squaring it, so the equivalence to the empirical Fisher information no longer holds. This implementation applies the same trick that AdamW uses to decouple weight decay from the adaptive gradient update in Adam: it applies the regularizer's gradient step directly to the weights instead of combining the two gradients. This recovers the equivalence with the approximate empirical Fisher information even when weight decay is used.
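
The difference between the two update rules can be sketched in plain Python. This is an illustration under stated assumptions, not the package's actual code: the function names `coupled_step` and `decoupled_step`, the toy gradients, and the AdamW-style `lr * weight_decay` scaling of the decay step are all hypothetical, and options such as learning-rate decay are omitted. With the learning rate set to 0 the parameters stay fixed, so the decoupled accumulator should match the (unnormalized) diagonal of the empirical Fisher, while the coupled one should not.

```python
import math


def coupled_step(w, grad, state_sum, lr, weight_decay, eps=1e-10):
    """Coupled L2 penalty: the decay term is added to the gradient before squaring."""
    g = [gi + weight_decay * wi for gi, wi in zip(grad, w)]
    # The accumulator sees the regularizer's gradient as well as the loss gradient.
    state_sum = [s + gi * gi for s, gi in zip(state_sum, g)]
    w = [wi - lr * gi / (math.sqrt(s) + eps) for wi, gi, s in zip(w, g, state_sum)]
    return w, state_sum


def decoupled_step(w, grad, state_sum, lr, weight_decay, eps=1e-10):
    """Decoupled weight decay: the accumulator only ever sees loss gradients."""
    state_sum = [s + gi * gi for s, gi in zip(state_sum, grad)]
    # Assumption: the decay is applied directly to the weights, scaled by lr as in AdamW.
    w = [wi * (1.0 - lr * weight_decay) for wi in w]
    w = [wi - lr * gi / (math.sqrt(s) + eps) for wi, gi, s in zip(w, grad, state_sum)]
    return w, state_sum


# Toy per-sample loss gradients (batch size 1), made up for illustration.
w0 = [1.0, -2.0]
per_sample_grads = [[0.5, 0.25], [-1.0, 0.5], [0.25, -0.75]]

# Unnormalized diagonal of the empirical Fisher: sum of squared per-sample gradients.
fisher_diag = [sum(g[i] ** 2 for g in per_sample_grads) for i in range(len(w0))]


def accumulate(step):
    """Run one pass with lr = 0 so the parameters never move."""
    w, state_sum = list(w0), [0.0] * len(w0)
    for g in per_sample_grads:
        w, state_sum = step(w, g, state_sum, lr=0.0, weight_decay=0.1)
    return state_sum


decoupled_sum = accumulate(decoupled_step)
coupled_sum = accumulate(coupled_step)

print("empirical Fisher diagonal:", fisher_diag)
print("decoupled accumulator:    ", decoupled_sum)
print("coupled accumulator:      ", coupled_sum)
```

With `weight_decay=0.0` the two functions produce identical results, which is why the difference only shows up once weight decay is switched on.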

Usage

Install with pip install adagradw, then use it as a drop-in replacement for torch.optim.Adagrad:

- optim = torch.optim.Adagrad(model.parameters(), lr=1e-3)
+ optim = adagradw.AdagradW(model.parameters(), lr=1e-3)

  for x, y in train_dataloader:
      loss = loss_fn(model(x), y)
      
      optim.zero_grad()
      loss.backward()
      optim.step()

File details

Details for the file adagradw-0.0.4.tar.gz.

File metadata

  • Download URL: adagradw-0.0.4.tar.gz
  • Size: 5.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.12

File hashes

Hashes for adagradw-0.0.4.tar.gz
Algorithm Hash digest
SHA256 f899140f60713e8d6ddf2a092668692e509b7c37cc1e08891903d20344339951
MD5 5b1e16fe00f53952d47df04d6a9d2fce
BLAKE2b-256 0b2e46240967243c5781a0529dd89212f60542307a76b0fa6e1fbaec6fe0ab10

File details

Details for the file adagradw-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: adagradw-0.0.4-py3-none-any.whl
  • Size: 5.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.12

File hashes

Hashes for adagradw-0.0.4-py3-none-any.whl
Algorithm Hash digest
SHA256 899f1d29c32fbba04e5efb9100be2e4a4bd8c64b48169b7434dbc15f778fd3b6
MD5 8cd1303fdb98d052013bdea7ddad147f
BLAKE2b-256 1d4fbd52d1a94e454455262223aef54bb0e79611846e9be59fe3b65a65977dcc
