psgdax: Preconditioned Stochastic Gradient Descent in JAX
psgdax provides a JAX implementation of the Preconditioned Stochastic Gradient Descent (PSGD) optimizer, compatible
with the Optax ecosystem. This library implements the Kronecker-product
preconditioner (Kron) based on the theory of Hessian fitting in Lie groups.
It is a translation of Xi-Lin Li's PyTorch reference implementation.
Mathematical Background
PSGD reformulates preconditioner estimation as a strongly convex optimization problem on Lie groups. Unlike quasi-Newton methods such as BFGS, or Kronecker-factored methods such as K-FAC, which operate in Euclidean space or on the manifold of SPD matrices, PSGD updates the preconditioner factor $Q$ (where $P = Q^T Q$) with multiplicative updates that avoid explicit matrix inversion.
The update rule minimizes the criterion $E[\delta g^T P \delta g + \delta \theta^T P^{-1} \delta \theta]$, which drives $P$ toward the inverse of the (absolute) Hessian when $(\delta\theta, \delta g)$ are parameter/gradient perturbation pairs, or toward the inverse square root of the gradient covariance in gradient-whitening mode.
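Concretely, given a pair $(\delta\theta, \delta g)$, let $h = Q\,\delta g$ and $t = Q^{-T}\delta\theta$. A sketch of the resulting multiplicative update on the triangular group (the EQ mode described below; step-size normalization omitted):

$$Q \leftarrow Q - \mu\,\mathrm{triu}\!\left(h h^T - t t^T\right) Q$$

At the minimum, $E[h h^T] = E[t t^T]$, i.e., $P\,\delta g$ statistically matches $\delta\theta$, mirroring the secant condition of quasi-Newton methods; the only inversion involved is a triangular solve for $t$.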
Installation
```bash
pip install psgdax
```
Usage
Basic Usage with Optax
psgdax follows the `optax.GradientTransformation` interface. It can be chained with other transformations, though the provided `kron` alias already handles standard scheduling, weight decay, and scale-by-learning-rate chaining automatically.
```python
import jax
import jax.numpy as jnp
from psgdax import kron

# Define parameters
params = {
    'w': jnp.zeros((128, 128)),
    'b': jnp.zeros((128,)),
}

# Initialize the optimizer.
# The default mode is Q0.5EQ1.5 (Procrustes-regularized update).
optimizer = kron(
    learning_rate=1e-3,
    b1=0.9,                  # momentum
    preconditioner_lr=0.1,   # learning rate for the preconditioner Q
    whiten_grad=True,        # whiten raw gradients (True) or momentum (False)
)
opt_state = optimizer.init(params)

@jax.jit
def step(params, opt_state, grads):
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = jax.tree.map(lambda p, u: p + u, params, updates)
    return new_params, new_opt_state
```
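For context, a complete training step only adds gradient computation. The `loss_fn` and batch layout below are hypothetical, for illustration:

```python
# Hypothetical least-squares loss over the params defined above.
def loss_fn(params, batch):
    preds = batch['x'] @ params['w'] + params['b']
    return jnp.mean((preds - batch['y']) ** 2)

@jax.jit
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = jax.tree.map(lambda p, u: p + u, params, updates)
    return params, opt_state, loss
```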
Advanced Usage: Scanned Layers
For deep architectures (e.g., Transformers) implemented via jax.lax.scan, psgdax supports explicit handling of
scanned layers to prevent unrolling computation graphs. This significantly improves compilation time and memory
efficiency.
```python
import jax
from psgdax import kron

# Assume a boolean pytree mask matching the structure of `params`,
# where True marks a scanned (stacked-over-layers) parameter.
scanned_layers_mask = ...

optimizer = kron(
    learning_rate=3e-4,
    scanned_layers=scanned_layers_mask,
    lax_map_scanned_layers=True,  # use lax.map for preconditioner updates
    lax_map_batch_size=8,
)
```
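How the mask is built depends on the model. A minimal sketch, assuming stacked transformer blocks whose parameters all carry the layer count as their leading axis (the `num_layers` convention here is an assumption, not a library requirement):

```python
import jax

num_layers = 12  # hypothetical depth of the scanned stack

# Mark every parameter whose leading axis equals the number of layers.
scanned_layers_mask = jax.tree.map(
    lambda p: p.ndim > 0 and p.shape[0] == num_layers,
    params,
)
```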
Configuration
Preconditioner Modes
The geometry of the preconditioner update $dQ$ is controlled via `preconditioner_mode`.

| Mode | Formula | Description |
|---|---|---|
| `Q0.5EQ1.5` | $dQ = Q^{0.5} \mathcal{E} Q^{1.5}$ | Recommended. Uses an online orthogonal Procrustes solver to keep $Q$ approximately SPD. Numerically stable at low precision. |
| `EQ` | $dQ = \mathcal{E} Q$ | The original PSGD update (triangular group). Requires triangular solves. |
| `QUAD` | Quadratic form | Keeps $Q$ symmetric positive definite via quadratic-form updates. |
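The mode strings in the table are the values accepted by `preconditioner_mode`; for example:

```python
from psgdax import kron

# 'Q0.5EQ1.5' is the default; 'EQ' selects the original triangular update.
optimizer = kron(learning_rate=1e-3, preconditioner_mode="EQ")
```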
Hyperparameters
- `preconditioner_lr`: The learning rate for $Q$. Recommended range $[0.01, 0.1]$.
- `preconditioner_update_probability`: Probability of updating $Q$ at each step. Can be a float or a schedule callable. Annealing this probability can reduce overhead.
- `max_size_triangular`: Dimensions larger than this default to diagonal preconditioners to save memory.
- `memory_save_mode`:
  - `None`: Standard behavior.
  - `'one_diag'`: Forces the largest dimension of each tensor to be diagonal.
  - `'all_diag'`: Forces all dimensions to be diagonal (similar to Shampoo without blocks).
- `whiten_grad`:
  - `True`: The preconditioner whitens the raw gradient.
  - `False`: The preconditioner whitens the momentum vector; requires `b1 > 0`. In this case the learning rate typically needs to be reduced by a factor of $\sqrt{\frac{1+\beta_1}{1-\beta_1}}$, where $\beta_1$ is `b1`.
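Putting these together, a sketch of a configuration that anneals the update probability (the schedule shape below is an illustrative choice, not a library default):

```python
import jax.numpy as jnp
from psgdax import kron

def update_prob_schedule(step):
    # Illustrative annealing: update Q on every step early in training,
    # decaying toward roughly 3% of steps later on.
    return jnp.maximum(0.03, jnp.exp(-1e-3 * step))

optimizer = kron(
    learning_rate=1e-3,
    preconditioner_lr=0.05,  # within the recommended [0.01, 0.1]
    preconditioner_update_probability=update_prob_schedule,
    memory_save_mode='one_diag',  # largest dim of each tensor goes diagonal
)
```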
Precision
Depending on hardware and configuration, JAX may run matrix multiplies in reduced precision (e.g., bfloat16 on TPUs) rather than full float32. PSGD is sensitive to precision during the preconditioner update, so both precisions are configurable:

- `precond_update_precision`: Precision for the preconditioner update. Defaults to `"tensorfloat32"`.
- `precond_grads_precision`: Precision for applying the preconditioner to the gradient.
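A sketch of overriding both, assuming the two arguments accept the same precision strings (e.g., `"bfloat16"`, `"tensorfloat32"`, `"float32"`):

```python
from psgdax import kron

optimizer = kron(
    learning_rate=1e-3,
    precond_update_precision="float32",     # full precision for Q updates
    precond_grads_precision="tensorfloat32",
)
```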
Implementation Details
Kronecker Decomposition
For a tensor parameter of shape $(n_1, n_2, \dots)$, PSGD approximates the Hessian inverse as a Kronecker product of smaller matrices $Q = Q_1 \otimes Q_2 \dots$.
- Dimensions where $n_i >$ `max_size` or $n_i^2 >$ `max_skew` $\cdot$ `numel` are approximated with diagonal matrices.
- Dimensions within these limits use full dense matrices.
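As an illustration of the per-dimension rule (a sketch, not the library's internal code; the default thresholds here are hypothetical):

```python
import math

def dim_types(shape, max_size=8192, max_skew=1.0):
    """Return 'diag' or 'dense' for each axis of a tensor,
    following the rule above."""
    numel = math.prod(shape)
    return [
        'diag' if (n > max_size or n * n > max_skew * numel) else 'dense'
        for n in shape
    ]

# A (50000, 256) embedding: the vocabulary axis exceeds max_size and goes
# diagonal, while the feature axis gets a dense 256x256 factor.
print(dim_types((50000, 256)))  # ['diag', 'dense']
```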
Eigenvalue Bounds
To ensure numerical stability without expensive eigenvalue decompositions, this implementation uses randomized lower-bound estimators for spectral norms (`_norm_lower_bound_spd` and `_norm_lower_bound_skh`) when updating Lipschitz constants.
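The principle behind such estimators (a generic sketch of a randomized power-iteration bound, not the library's exact routine): for any unit vector $v$, $\|Av\| \le \|A\|_2$, and a few power iterations from a random start tighten this lower bound quickly for SPD matrices.

```python
import jax
import jax.numpy as jnp

def spectral_norm_lower_bound(A, key, iters=2):
    """Cheap lower bound on ||A||_2: for unit v, ||A v|| <= ||A||_2, and
    power iterations push v toward the top eigenvector, tightening the bound."""
    v = jax.random.normal(key, (A.shape[0],))
    for _ in range(iters):
        v = v / jnp.linalg.norm(v)
        v = A @ v
    return jnp.linalg.norm(v)
```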
Citations
This library is a translation of the work by Xi-Lin Li. If you use this optimizer, please cite the original papers:
```bibtex
@article{li2015preconditioned,
  title={Preconditioned Stochastic Gradient Descent},
  author={Li, Xi-Lin},
  journal={arXiv preprint arXiv:1512.04202},
  year={2015}
}

@article{li2018preconditioner,
  title={Preconditioner on Matrix Lie Group for SGD},
  author={Li, Xi-Lin},
  journal={arXiv preprint arXiv:1809.10232},
  year={2018}
}

@article{li2024stochastic,
  title={Stochastic Hessian Fittings with Lie Groups},
  author={Li, Xi-Lin},
  journal={arXiv preprint arXiv:2402.11858},
  year={2024}
}
```
Acknowledgments
- Xi-Lin Li: Author of the original PSGD algorithm and PyTorch implementation.
- Evanatyourservice: Author of the preliminary JAX port upon which this library improves.