
An implementation of the PSGD Kron optimizer in PyTorch.


PSGD Kron

For Xi-Lin's original PSGD repo, see psgd_torch.

For JAX versions, see psgd_jax and distributed_kron.

Implementation of PSGD Kron for PyTorch. PSGD is a second-order optimizer originally created by Xi-Lin Li that uses either a Hessian-based or whitening-based (gg^T) preconditioner and Lie groups to improve training convergence, generalization, and efficiency. I highly suggest taking a look at the readme of Xi-Lin's PSGD repo, linked above, for interesting details on how PSGD works and experiments using PSGD. There are also paper resources listed near the bottom of this readme.

kron:

The most versatile and easy-to-use PSGD optimizer is kron, which uses a Kronecker-factored preconditioner. It has fewer hyperparameters that need tuning than Adam, and can generally act as a drop-in replacement.
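To see why a Kronecker-factored preconditioner is cheap, it helps to recall the general identity (B ⊗ A) vec(G) = vec(A G B^T): applying one large matrix to the flattened gradient is equivalent to multiplying the gradient by one small factor per dimension. The pure-Python sketch below checks this on a tiny example. It only illustrates the general idea and is not how kron.py is implemented; all helper names are made up for the illustration.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def kron(a, b):
    """Kronecker product of two nested-list matrices."""
    return [[a[i][j] * b[k][l]
             for j in range(len(a[0])) for l in range(len(b[0]))]
            for i in range(len(a)) for k in range(len(b))]


def vec(a):
    """Stack the columns of a matrix into one flat list (column-major)."""
    return [a[i][j] for j in range(len(a[0])) for i in range(len(a))]


# A (3, 2) "gradient" with one small factor per dimension.
A = [[2, 0, 0], [1, 3, 0], [0, 1, 1]]   # acts on the dimension of size 3
B = [[1, 0], [2, 1]]                    # acts on the dimension of size 2
G = [[1, 2], [3, 4], [5, 6]]

# Dense route: build the full 6x6 matrix and apply it to the flattened gradient.
big = kron(B, A)
dense = [sum(p * g for p, g in zip(row, vec(G))) for row in big]

# Factored route: never build the 6x6 matrix.
factored = vec(matmul(matmul(A, G), transpose(B)))

same_result = dense == factored
dense_entries = len(big) * len(big[0])                        # full matrix entries
factor_entries = len(A) * len(A[0]) + len(B) * len(B[0])      # entries in the two factors
print(same_result, dense_entries, factor_entries)
```

The gap between the two storage counts grows quickly with layer size, which is the reason for keeping one factor per dimension instead of a single preconditioner over all parameters of a layer.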

Thanks

Shoutout to @ClashLuke for developing efficiency improvements for PSGD Kron in the heavyball repo, and for the design of 'smart_one_diag' memory save mode, which is a method to improve memory usage and speed with almost no cost to the optimizer's effectiveness. In Xi-Lin's repo, the equivalent is setting preconditioner_max_skew=1.

Installation

pip install kron-torch

Basic Usage (Kron)

By default, Kron schedules the preconditioner update probability to start at 1.0 and anneal to 0.03 early in training, so training will be slightly slower at the start but will speed up by around 4k steps.

For basic usage, use the Kron optimizer like any other PyTorch optimizer:

from kron_torch import Kron

optimizer = Kron(params)

optimizer.zero_grad()
loss.backward()
optimizer.step()

Basic hyperparameters:

TLDR: Start with a learning rate around 3x smaller than Adam's. There is no b2 or epsilon.

These next 3 settings control whether a dimension's preconditioner is diagonal or triangular. For example, for a layer with shape (256, 128), triangular preconditioners would be shapes (256, 256) and (128, 128), and diagonal preconditioners would be shapes (256,) and (128,). Depending on how these settings are chosen, kron can balance between memory/speed and effectiveness. Defaults lead to most preconditioners being triangular except for 1-dimensional layers and very large dimensions.

max_size_triangular: Any dimension with size above this value will have a diagonal preconditioner.

min_ndim_triangular: Any tensor with fewer than this number of dims will have all diagonal preconditioners. Default is 2, so single-dim layers like bias and scale will use diagonal preconditioners.

memory_save_mode: Can be None, 'smart_one_diag', 'one_diag', or 'all_diag'. None is the default and lets all preconditioners be triangular. 'smart_one_diag' sets the largest dim to diagonal only if it's strictly larger than the second largest dim (if it stands out). 'one_diag' sets the largest or last dim per layer as diagonal using np.argsort(shape)[::-1][0]. 'all_diag' sets all preconditioners to be diagonal.
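The selection rules described by these three settings can be sketched in plain Python. This is a minimal illustration rather than the library's code: the function name `preconditioner_kinds` is hypothetical, the tie-breaking for 'one_diag' (the last of the largest dims) is an assumption based on the np.argsort expression above, the value 8192 used for max_size_triangular is only illustrative, and kron.py may apply further special cases that are not covered here.

```python
def preconditioner_kinds(shape, max_size_triangular, min_ndim_triangular=2,
                         memory_save_mode=None):
    """Return 'diagonal' or 'triangular' for each dimension of `shape`.

    Hypothetical helper that mirrors the rules described in the text.
    """
    valid_modes = (None, "smart_one_diag", "one_diag", "all_diag")
    if memory_save_mode not in valid_modes:
        raise ValueError(f"unknown memory_save_mode: {memory_save_mode!r}")

    ndim = len(shape)
    force_diag = [False] * ndim

    if memory_save_mode == "all_diag":
        force_diag = [True] * ndim
    elif memory_save_mode is not None and ndim > 0:
        # Index of the largest dim; ties go to the last such dim (assumption).
        largest = max(range(ndim), key=lambda i: (shape[i], i))
        if memory_save_mode == "one_diag":
            force_diag[largest] = True
        else:
            # 'smart_one_diag': only when the largest dim stands out.
            ordered = sorted(shape)
            if ndim > 1 and ordered[-1] > ordered[-2]:
                force_diag[largest] = True

    kinds = []
    for size, forced in zip(shape, force_diag):
        diagonal = (forced
                    or size > max_size_triangular
                    or ndim < min_ndim_triangular)
        kinds.append("diagonal" if diagonal else "triangular")
    return kinds


# 8192 is an illustrative threshold, not a documented default.
print(preconditioner_kinds((256, 128), 8192))
print(preconditioner_kinds((256, 128), 8192, memory_save_mode="smart_one_diag"))
print(preconditioner_kinds((256, 256), 8192, memory_save_mode="smart_one_diag"))
print(preconditioner_kinds((128,), 8192))
print(preconditioner_kinds((50000, 512), 8192))
```

For the (256, 128) layer from the example above, two triangular preconditioners hold shapes (256, 256) and (128, 128), while switching the larger dim to diagonal replaces the (256, 256) factor with a vector of shape (256,).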

preconditioner_update_probability: Preconditioner update probability uses a schedule by default that works well for most cases. It anneals from 1 to 0.03 at the beginning of training, so training will be slightly slower at the start but will speed up by around 4k steps. PSGD generally benefits from more preconditioner updates at the start of training, but once the preconditioner is learned it's okay to do them less often. An easy way to adjust update frequency is to define your own schedule using the precond_update_prob_schedule function in kron.py (just changing the min_prob value is easiest) and pass this into kron through the preconditioner_update_probability hyperparameter.
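One way to picture the update probability is as a per-step coin flip: on each optimizer step the preconditioner is refreshed only if a random draw falls below the current probability. The simulation below is a simplified sketch under that assumption; `count_precond_updates` is a hypothetical helper, and the decision logic inside kron.py may differ in detail.

```python
import random


def count_precond_updates(prob_fn, num_steps, seed=0):
    """Count the steps on which the preconditioner would be updated."""
    rng = random.Random(seed)
    updates = 0
    for step in range(num_steps):
        # Update the preconditioner with probability prob_fn(step).
        if rng.random() < prob_fn(step):
            updates += 1
    return updates


# Probability 1.0 (start of training) versus 0.03 (after annealing).
always = count_precond_updates(lambda step: 1.0, 1000)
rare = count_precond_updates(lambda step: 0.03, 1000)
print(always, rare)
```

At probability 1.0 every step pays for a preconditioner update, while at 0.03 only a few dozen steps per thousand do, which is where the speed-up after the initial phase comes from.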

This is the default schedule defined in the precond_update_prob_schedule function at the top of kron.py:

[Figure: default preconditioner update probability schedule]
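A schedule with this shape can be approximated in a few lines of plain Python. The sketch below is not copied from kron.py: it assumes an exponential decay with a flat start, and the names `sketch_update_prob_schedule`, `max_prob`, `decay`, and `flat_start`, along with the values 0.001 and 500, are assumptions chosen so that the curve goes from 1.0 to 0.03 at around 4k steps as described above. Only `min_prob` is named in the text.

```python
import math


def sketch_update_prob_schedule(max_prob=1.0, min_prob=0.03,
                                decay=0.001, flat_start=500):
    """Build a step -> probability function (illustrative values only)."""

    def schedule(step):
        # Exponential decay that begins after `flat_start` steps,
        # clamped to the range [min_prob, max_prob].
        prob = max_prob * math.exp(-decay * (step - flat_start))
        return min(max(prob, min_prob), max_prob)

    return schedule


schedule = sketch_update_prob_schedule()
for step in (0, 500, 1500, 4000, 10000):
    print(step, round(schedule(step), 4))

# Raising the floor keeps preconditioner updates more frequent late in training.
frequent = sketch_update_prob_schedule(min_prob=0.1)
print(frequent(10000))
```

With these assumed values the probability stays at 1.0 for the first 500 steps and reaches its floor of 0.03 shortly after step 4000. A schedule built with the library's own precond_update_prob_schedule is what gets passed to Kron through the preconditioner_update_probability hyperparameter.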

Resources

PSGD papers and resources listed from Xi-Lin's repo

  1. Xi-Lin Li. Preconditioned stochastic gradient descent, arXiv:1512.04202, 2015. (General ideas of PSGD, preconditioner fitting losses and Kronecker product preconditioners.)
  2. Xi-Lin Li. Preconditioner on matrix Lie group for SGD, arXiv:1809.10232, 2018. (Focus on preconditioners with the affine Lie group.)
  3. Xi-Lin Li. Black box Lie group preconditioners for SGD, arXiv:2211.04422, 2022. (Mainly about the LRA preconditioner. See these supplementary materials for detailed math derivations.)
  4. Xi-Lin Li. Stochastic Hessian fittings on Lie groups, arXiv:2402.11858, 2024. (Some theoretical works on the efficiency of PSGD. The Hessian fitting problem is shown to be strongly convex on set ${\rm GL}(n, \mathbb{R})/R_{\rm polar}$.)
  5. Omead Pooladzandi, Xi-Lin Li. Curvature-informed SGD via general purpose Lie-group preconditioners, arXiv:2402.04553, 2024. (Plenty of benchmark results and analyses for PSGD vs. other optimizers.)

License

CC BY 4.0

This work is licensed under a Creative Commons Attribution 4.0 International License.

2024 Evan Walters, Omead Pooladzandi, Xi-Lin Li
