An implementation of PSGD Kron optimizer in PyTorch.
Project description
PSGD Kron
For original PSGD repo, see psgd_torch.
For JAX version, see psgd_jax.
Implementation of PSGD Kron optimizer in PyTorch. PSGD is a second-order optimizer originally created by Xi-Lin Li that uses either a hessian-based or whitening-based (gg^T) preconditioner and lie groups to improve training convergence, generalization, and efficiency. I highly suggest taking a look at Xi-Lin's PSGD repo's readme linked to above for interesting details on how PSGD works and experiments using PSGD.
kron
:
The most versatile and easy-to-use PSGD optimizer is kron
, which uses a
Kronecker-factored preconditioner. It has less hyperparameters that need tuning than adam, and can
be a drop-in replacement for adam. It keeps a dim's preconditioner as either triangular
or diagonal based on max_size_triangular
and max_skew_triangular
. For example, for a layer
with shape (256, 128, 64), triangular preconditioners would be shapes (256, 256), (128, 128), and
(64, 64) and diagonal preconditioners would be shapes (256,), (128,), and (64,). Depending on how
these two settings are chosen, kron
can balance between memory/speed and performance (see below).
Installation
pip install kron-torch
Basic Usage (Kron)
Kron schedules the preconditioner update probability by default to start at 1.0 and anneal to 0.03 at the beginning of training, so training will be slightly slower at the start but will speed up to near adam's speed by around 3k steps.
For basic usage, use kron
optimizer like any other pytorch optimizer:
from kron_torch import Kron
optimizer = Kron(params)
optimizer.zero_grad()
loss.backward()
optimizer.step()
Basic hyperparameters:
TLDR: Learning rate acts similarly to adam's, but can be set a little higher like 0.001 -> 0.003. Weight decay should be set lower than adam's, like 0.1 -> 0.03 or 0.01. There is no b2 or epsilon.
learning_rate
: Kron's learning rate acts similarly to adam's, but can withstand a higher
learning rate. Try setting 3x higher. If 0.001 was best for adam, try setting kron's to 0.003.
weight_decay
: PSGD does not rely on weight decay for generalization as much as adam, and too
high weight decay can hurt performance. Try setting 3-10x lower. If the best weight decay for
adam was 0.1, you can set kron's to 0.03 or 0.01.
max_size_triangular
: Anything above this value will have a diagonal preconditioner, anything
below will have a triangular preconditioner. So if you have a dim with size 16,384 that you want
to use a diagonal preconditioner for, set max_size_triangular
to something like 15,000. Default
is 8192.
max_skew_triangular
: Any tensor with skew above this value with make the larger dim diagonal.
For example, with the default value for max_skew_triangular
as 10, a bias layer of shape
(256,) would be diagonal because 256/1 > 10, and an embedding dim of shape (50000, 768) would
be (diag, tri) because 50000/768 is greater than 10. The default value of 10 usually makes
layers like bias, scale, and vocab embedding use diagonal with the rest as triangular.
Interesting note: Setting max_skew_triangular
to 0 will make all layers have (diag, tri)
preconditioners which uses slightly less memory than adam (and runs slightly faster). Setting
max_size_triangular
to 0 will make all layers have diagonal preconditioners which uses the least
memory and runs the fastest, but training might be slower.
preconditioner_update_probability
: Preconditioner update probability uses a schedule by default
that works well for most cases. It anneals from 1 to 0.03 at the beginning of training, so training
will be slightly slower at the start but will speed up to near adam's speed by around 3k steps.
See kron.py for more hyperparameter details.
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Hashes for kron_torch-0.1.7-py3-none-any.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | dd33918693026ee7a725562fff31e4e925383ea2db2fdc2058c544bdfd01ada6 |
|
MD5 | 2597777ee400294c9b2f016358449d6b |
|
BLAKE2b-256 | 37a40f2f963006faaebe414e69a9281e4a877bd267e5c2d08de746962dc1dbdb |