JAX implementation of the Preconditioned Stochastic Gradient Descent (PSGD) optimizer.

Project description

psgdax: Preconditioned Stochastic Gradient Descent in JAX

psgdax provides a JAX implementation of the Preconditioned Stochastic Gradient Descent (PSGD) optimizer, compatible with the Optax ecosystem. This library implements the Kronecker-product preconditioner (Kron) based on the theory of Hessian fitting in Lie groups.

This implementation translates the PyTorch reference implementation by Xi-Lin Li.

Mathematical Background

PSGD reformulates preconditioner estimation as a strongly convex optimization problem on Lie groups. Unlike standard quasi-Newton methods (e.g., BFGS, KFAC) that operate in Euclidean space or the manifold of SPD matrices, PSGD updates the preconditioner $Q$ (where $P = Q^T Q$) using multiplicative updates that avoid explicit matrix inversion.

The update rule minimizes the criterion $E[\delta g^T P \delta g + \delta \theta^T P^{-1} \delta \theta]$ over pairs of parameter perturbations $\delta\theta$ and the induced gradient perturbations $\delta g$; at the minimum, $P$ approximates the inverse (absolute) Hessian, or, in gradient-whitening mode, the inverse square root of the gradient covariance.
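
For intuition, the criterion for a single pair $(\delta\theta, \delta g)$ and a dense $Q$ can be written out directly (an illustrative sketch only; psgdax never materializes $P$ like this):

import jax.numpy as jnp

def psgd_criterion(Q, dg, dtheta):
    # P = Q^T Q; fitting loss = dg^T P dg + dtheta^T P^{-1} dtheta
    P = Q.T @ Q
    return dg @ (P @ dg) + dtheta @ jnp.linalg.solve(P, dtheta)

# At the minimum (with dg = H dtheta), P = Q^T Q approaches the inverse absolute Hessian.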

Installation

pip install psgdax

Usage

Basic Usage with Optax

psgdax follows the optax.GradientTransformation interface and can be chained with other transformations, though the provided kron alias already handles learning-rate scheduling, weight decay, and scaling by the learning rate.

import jax
import jax.numpy as jnp
from psgdax import kron

# Define parameters
params = {
    'w': jnp.zeros((128, 128)),
    'b': jnp.zeros((128,))
}

# Initialize optimizer
# The default mode is Q0.5EQ1.5 (Procrustes-regularized update)
optimizer = kron(
    learning_rate=1e-3,
    b1=0.9,                 # Momentum
    preconditioner_lr=0.1,  # Learning rate for the preconditioner Q
    whiten_grad=True        # Whiten gradients (True) or Momentum (False)
)

opt_state = optimizer.init(params)

@jax.jit
def step(params, opt_state, grads):
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = jax.tree.map(lambda p, u: p + u, params, updates)
    return new_params, new_opt_state
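
Gradients are obtained as usual with jax.grad and passed to step. Continuing the example above with a placeholder least-squares loss and dummy data (illustrative only):

def loss_fn(params, x, y):
    pred = x @ params['w'] + params['b']
    return jnp.mean((pred - y) ** 2)

x = jnp.ones((32, 128))   # dummy batch
y = jnp.zeros((32, 128))

for _ in range(10):
    grads = jax.grad(loss_fn)(params, x, y)
    params, opt_state = step(params, opt_state, grads)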

Advanced Usage: Scanned Layers

For deep architectures (e.g., Transformers) implemented via jax.lax.scan, psgdax supports explicit handling of scanned layers so that preconditioner updates are not unrolled across the layer axis. This significantly improves compilation time and memory efficiency.

import jax
from psgdax import kron

# Assume a boolean pytree mask where True indicates a scanned parameter
# matching the structure of 'params'
scanned_layers_mask = ... 

optimizer = kron(
    learning_rate=3e-4,
    scanned_layers=scanned_layers_mask,
    lax_map_scanned_layers=True, # Use lax.map for preconditioner updates
    lax_map_batch_size=8
)
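
The mask mirrors the structure of params, marking parameters whose leading axis is the scan (layer) axis. A hypothetical example (the names and shapes below are placeholders, not part of psgdax):

import jax.numpy as jnp

params = {
    'embed': jnp.zeros((1000, 512)),
    'blocks': {'w': jnp.zeros((12, 512, 512))},  # 12 stacked layers along the leading axis
}
scanned_layers_mask = {
    'embed': False,              # ordinary parameter
    'blocks': {'w': True},       # scanned: leading axis is the layer axis
}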

Configuration

Preconditioner Modes

The geometry of the preconditioner update $dQ$ is controlled via preconditioner_mode.

  • Q0.5EQ1.5: $dQ = Q^{0.5} \mathcal{E} Q^{1.5}$. Recommended. Uses an online orthogonal Procrustes solver to keep $Q$ approximately SPD; numerically stable at low precision.
  • EQ: $dQ = \mathcal{E} Q$. The original PSGD update (triangular group); requires triangular solves.
  • QUAD: Quadratic-form update; ensures $Q$ remains symmetric positive definite.
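
Selecting a mode is a single keyword argument; a minimal sketch using the names from the list above:

from psgdax import kron

optimizer = kron(
    learning_rate=1e-3,
    preconditioner_mode="EQ",  # "Q0.5EQ1.5" (default), "EQ", or "QUAD"
)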

Hyperparameters

  • preconditioner_lr: The learning rate for $Q$. Recommended range $[0.01, 0.1]$.
  • preconditioner_update_probability: Probability of updating $Q$ at each step. Can be a float or a schedule callable. Annealing this probability can reduce overhead.
  • max_size_triangular: Dimensions larger than this will default to diagonal preconditioners to save memory.
  • memory_save_mode:
    • None: Standard behavior.
    • 'one_diag': Forces the largest dimension of a tensor to be diagonal.
    • 'all_diag': Forces all dimensions to be diagonal (similar to Shampoo without blocks).
  • whiten_grad:
    • True: The preconditioner whitens the raw gradient.
    • False: The preconditioner whitens the momentum vector. Requires b1 > 0. Note: in this mode the learning rate typically needs to be reduced by a factor of $\sqrt{\frac{1+\beta}{1-\beta}}$ (roughly 4.4 for $\beta = 0.9$); see the sketch after this list.
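
Putting several of these together, a sketch that anneals the preconditioner update probability and whitens momentum instead of gradients (keyword names as documented above; the schedule signature is an assumption):

import jax.numpy as jnp
from psgdax import kron

def update_prob_schedule(step):
    # Assumed signature: step count -> probability; anneal from 1.0 down to 0.03.
    return jnp.maximum(1.0 - step / 4000.0, 0.03)

optimizer = kron(
    learning_rate=2.3e-4,   # roughly 1e-3 / sqrt((1 + 0.9) / (1 - 0.9))
    b1=0.9,
    whiten_grad=False,      # precondition the momentum rather than the raw gradient
    preconditioner_update_probability=update_prob_schedule,
    memory_save_mode='one_diag',
    max_size_triangular=8192,
)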

Precision

Depending on configuration, JAX may run matrix multiplications in reduced precision (for example, bfloat16 on TPUs or TensorFloat-32 on recent GPUs). PSGD is sensitive to precision during the preconditioner update; a configuration sketch follows the list below.

  • precond_update_precision: Defaults to "tensorfloat32".
  • precond_grads_precision: Precision for the application of the preconditioner to the gradient.
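
A hedged configuration sketch, assuming these arguments accept the same strings as jax.default_matmul_precision ("bfloat16", "tensorfloat32", "float32"):

from psgdax import kron

optimizer = kron(
    learning_rate=1e-3,
    precond_update_precision="float32",      # assumed value; full float32 for the Q update
    precond_grads_precision="tensorfloat32",
)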

Implementation Details

Kronecker Decomposition

For a tensor parameter of shape $(n_1, n_2, \dots)$, PSGD approximates the Hessian inverse as a Kronecker product of smaller matrices $Q = Q_1 \otimes Q_2 \dots$.

  • Dimensions where $n_i >$ max_size or $n_i^2 >$ max_skew $\cdot$ numel are approximated via diagonal matrices.
  • All other dimensions (within these limits) use full dense matrices.
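
For a 2-D parameter with dense factors $Q_1$ and $Q_2$, applying $P = Q^T Q$ never forms the Kronecker product explicitly; up to the vectorization convention it reduces to two small matrix multiplications (an illustrative identity, not the library's internal kernel):

import jax.numpy as jnp

def precondition_matrix_grad(Q1, Q2, G):
    # P vec(G) == vec(P1 @ G @ P2), where P1 = Q1^T Q1 and P2 = Q2^T Q2.
    P1 = Q1.T @ Q1
    P2 = Q2.T @ Q2
    return P1 @ G @ P2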

Eigenvalue Bounds

To ensure numerical stability without expensive eigenvalue decompositions, this implementation utilizes randomized lower-bound estimators for spectral norms (_norm_lower_bound_spd and _norm_lower_bound_skh) during the update of Lipschitz constants.
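
These estimators rest on the fact that $\|A x\| / \|x\|$ lower-bounds the spectral norm of $A$ for any nonzero $x$; a generic sketch of the idea (not the library's exact routine):

import jax
import jax.numpy as jnp

def spectral_norm_lower_bound(A, key, iters=2):
    # ||A x|| <= ||A||_2 for any unit x, so ||A x|| is always a valid lower bound;
    # a few power iterations move x toward the top eigenvector and tighten it.
    x = jax.random.normal(key, (A.shape[0],))
    x = x / jnp.linalg.norm(x)
    for _ in range(iters):
        x = A @ x
        x = x / jnp.linalg.norm(x)
    return jnp.linalg.norm(A @ x)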

Citations

This library is a translation of the work by Xi-Lin Li. If you use this optimizer, please cite the original papers:

@article{li2015preconditioned,
  title={Preconditioned Stochastic Gradient Descent},
  author={Li, Xi-Lin},
  journal={arXiv preprint arXiv:1512.04202},
  year={2015}
}

@article{li2018preconditioner,
  title={Preconditioner on Matrix Lie Group for SGD},
  author={Li, Xi-Lin},
  journal={arXiv preprint arXiv:1809.10232},
  year={2018}
}

@article{li2024stochastic,
  title={Stochastic Hessian Fittings with Lie Groups},
  author={Li, Xi-Lin},
  journal={arXiv preprint arXiv:2402.11858},
  year={2024}
}

Acknowledgments


Download files

Download the file for your platform.

Source Distribution

psgdax-0.1.1.tar.gz (13.4 kB)

Uploaded Source

Built Distribution


psgdax-0.1.1-py3-none-any.whl (14.4 kB)

Uploaded Python 3

File details

Details for the file psgdax-0.1.1.tar.gz.

File metadata

  • Download URL: psgdax-0.1.1.tar.gz
  • Upload date:
  • Size: 13.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.13

File hashes

Hashes for psgdax-0.1.1.tar.gz
  • SHA256: c425f432ad88fcc7f9eb489d85845ee6f7a2cde2db5030d878e34a8373b2106c
  • MD5: e6edee57626093f871fee28c1f1366a7
  • BLAKE2b-256: 356317973328de1ca9154ee8655f9004a0ae7182f6889a05cb26456d4fb6f20f


File details

Details for the file psgdax-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: psgdax-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 14.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.13

File hashes

Hashes for psgdax-0.1.1-py3-none-any.whl
  • SHA256: 0bd335ddca60f45a26860cfd120e21775e3a1741a5fadb25af03b3fb7efd0309
  • MD5: 46f1f3bad6dd29a102095ea3a3f09451
  • BLAKE2b-256: 645ab8398001a0fce5f7746b89d426f870162822ba964d00d21edf872e106f6b

