
psgdax: Preconditioned Stochastic Gradient Descent in JAX

psgdax provides a JAX implementation of the Preconditioned Stochastic Gradient Descent (PSGD) optimizer, compatible with the Optax ecosystem. This library implements the Kronecker-product preconditioner (Kron) based on the theory of Hessian fitting in Lie groups.

This implementation translates the PyTorch reference implementation by Xi-Lin Li.

Mathematical Background

PSGD reformulates preconditioner estimation as a strongly convex optimization problem on Lie groups. Unlike quasi-Newton methods such as BFGS, or Kronecker-factored curvature methods such as K-FAC, which operate in Euclidean space or on the manifold of SPD matrices, PSGD updates the preconditioner $Q$ (where $P = Q^T Q$) using multiplicative updates that avoid explicit matrix inversion.

The update rule minimizes the criterion $E[\delta g^T P \delta g + \delta \theta^T P^{-1} \delta \theta]$, ensuring the preconditioner approximates the Hessian or the inverse covariance of gradients.
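
To see what this criterion fits, note that setting its derivative with respect to $P$ to zero gives $P \, E[\delta g \, \delta g^T] \, P = E[\delta \theta \, \delta \theta^T]$. Under the idealized, noise-free assumption that each pair satisfies $\delta g = H \delta \theta$ for an SPD Hessian $H$ (and enough independent directions are sampled), the unique SPD solution is $P = H^{-1}$, i.e., the Newton preconditioner.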

Installation

pip install psgdax

Usage

Basic Usage with Optax

psgdax follows the optax.GradientTransformation interface. It can be chained with other transformations, though the provided kron transformation already handles learning-rate scheduling, weight decay, and scaling by the learning rate.

import jax
import jax.numpy as jnp
from psgdax import kron

# Define parameters
params = {
    'w': jnp.zeros((128, 128)),
    'b': jnp.zeros((128,))
}

# Initialize optimizer
# The default mode is Q0.5EQ1.5 (Procrustes-regularized update)
optimizer = kron(
    learning_rate=1e-3,
    b1=0.9,                 # Momentum
    preconditioner_lr=0.1,  # Learning rate for the preconditioner Q
    whiten_grad=True        # Whiten gradients (True) or Momentum (False)
)

opt_state = optimizer.init(params)

@jax.jit
def step(params, opt_state, grads):
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = jax.tree.map(lambda p, u: p + u, params, updates)
    return new_params, new_opt_state
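
For completeness, here is one illustrative way to drive the step function above; the quadratic loss_fn and the toy data x and y below are placeholders, not part of the library.

# Toy loss for illustration; any differentiable loss works.
def loss_fn(params, x, y):
    preds = x @ params['w'] + params['b']
    return jnp.mean((preds - y) ** 2)

x = jnp.ones((32, 128))
y = jnp.zeros((32, 128))

# Compute gradients and apply one PSGD step.
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
params, opt_state = step(params, opt_state, grads)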

Advanced Usage: Scanned Layers

For deep architectures (e.g., Transformers) implemented via jax.lax.scan, psgdax supports explicit handling of scanned layers to prevent unrolling computation graphs. This significantly improves compilation time and memory efficiency.

import jax
from psgdax import kron

# Assume a boolean pytree mask where True indicates a scanned parameter
# matching the structure of 'params'
scanned_layers_mask = ... 

optimizer = kron(
    learning_rate=3e-4,
    scanned_layers=scanned_layers_mask,
    lax_map_scanned_layers=True, # Use lax.map for preconditioner updates
    lax_map_batch_size=8
)
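
What the mask looks like depends entirely on your parameter pytree. As a hypothetical example, if all stacked transformer-block weights live under a 'blocks' subtree (leading axis indexing layers) while everything else is unstacked, the mask could be built like this:

import jax
import jax.numpy as jnp

# Hypothetical parameter structure: 'blocks' holds per-layer weights stacked
# along a leading axis of size 12; 'embed' is a single unstacked array.
params = {
    'embed': jnp.zeros((1000, 256)),
    'blocks': {'w': jnp.zeros((12, 256, 256)), 'b': jnp.zeros((12, 256))},
}

# True marks parameters whose leading axis should be treated as a scan axis.
scanned_layers_mask = {
    'embed': False,
    'blocks': jax.tree.map(lambda _: True, params['blocks']),
}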

Configuration

Preconditioner Modes

The geometry of the preconditioner update $dQ$ is controlled via preconditioner_mode.

  • Q0.5EQ1.5 ($dQ = Q^{0.5} \mathcal{E} Q^{1.5}$): Recommended. Uses an online orthogonal Procrustes solver to keep $Q$ approximately SPD. Numerically stable at low precision.
  • EQ ($dQ = \mathcal{E} Q$): The original PSGD update on the triangular group. Requires triangular solves.
  • QUAD (quadratic-form update): Ensures $Q$ remains symmetric positive definite via quadratic-form updates.

Hyperparameters

  • preconditioner_lr: The learning rate for $Q$. Recommended range $[0.01, 0.1]$.
  • preconditioner_update_probability: Probability of updating $Q$ at each step. Can be a float or a schedule callable; annealing this probability can reduce overhead (see the sketch after this list).
  • max_size_triangular: Dimensions larger than this will default to diagonal preconditioners to save memory.
  • memory_save_mode:
    • None: Standard behavior.
    • 'one_diag': Forces the largest dimension of a tensor to be diagonal.
    • 'all_diag': Forces all dimensions to be diagonal (similar to Shampoo without blocks).
  • whiten_grad:
    • True: The preconditioner whitens the raw gradient.
    • False: The preconditioner whitens the momentum vector. Requires b1 > 0. Note: If False, the learning rate typically needs to be reduced by a factor of $\sqrt{\frac{1+\beta}{1-\beta}}$.
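
A minimal sketch of such an annealing schedule, assuming the callable receives the step count and returns a probability; the shape of the curve and the constants below are illustrative, not the library's defaults.

import jax.numpy as jnp
from psgdax import kron

def update_prob_schedule(step, flat_start=500, min_prob=0.03, decay=1e-3):
    # Keep the update probability at 1.0 for the first `flat_start` steps,
    # then decay it exponentially toward `min_prob`. All constants here are
    # illustrative.
    prob = jnp.exp(-decay * jnp.maximum(step - flat_start, 0))
    return jnp.clip(prob, min_prob, 1.0)

optimizer = kron(
    learning_rate=1e-3,
    preconditioner_update_probability=update_prob_schedule,
)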

Precision

Depending on configuration, JAX may run matrix multiplications in reduced precision (for example, bfloat16 on TPUs) even when parameters are stored in float32. PSGD is sensitive to precision during the preconditioner update.

  • precond_update_precision: Defaults to "tensorfloat32".
  • precond_grads_precision: Precision for the application of the preconditioner to the gradient.
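
If the preconditioner update is unstable at lower precision, one option (a sketch; check the kron docstring for the exact accepted strings) is to force float32 for the update:

from psgdax import kron

# Force float32 during the preconditioner update; precond_grads_precision
# can be set the same way if the gradient application also needs it.
optimizer = kron(
    learning_rate=1e-3,
    precond_update_precision="float32",
)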

Implementation Details

Kronecker Decomposition

For a tensor parameter of shape $(n_1, n_2, \dots)$, PSGD approximates the Hessian inverse as a Kronecker product of smaller matrices $Q = Q_1 \otimes Q_2 \dots$.

  • Dimensions where $n_i >$ max_size or $n_i^2 >$ max_skew $\cdot$ numel are approximated with diagonal matrices.
  • Dimensions within these limits use full dense matrices.
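
For intuition, in the matrix case with gradient $G \in \mathbb{R}^{n_1 \times n_2}$ and factors $Q_1, Q_2$ of compatible sizes, applying the Kronecker-factored preconditioner reduces to two small matrix multiplications rather than one $n_1 n_2 \times n_1 n_2$ product: $(Q_1 \otimes Q_2)^T (Q_1 \otimes Q_2)\,\mathrm{vec}(G) = \mathrm{vec}\!\left(Q_2^T Q_2 \, G \, Q_1^T Q_1\right)$ (the exact pairing of factors to axes depends on the vec convention used).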

Eigenvalue Bounds

To ensure numerical stability without expensive eigenvalue decompositions, this implementation utilizes randomized lower-bound estimators for spectral norms (_norm_lower_bound_spd and _norm_lower_bound_skh) during the update of Lipschitz constants.
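
The idea behind such an estimator can be sketched in a few lines; this is illustrative and not the library's internal _norm_lower_bound_spd/_norm_lower_bound_skh code. For any nonzero probe vector $x$, $\|Ax\| / \|x\| \le \|A\|_2$, so probing with a well-chosen column gives a cheap lower bound without an eigendecomposition.

import jax.numpy as jnp

def spectral_norm_lower_bound(A):
    # For any nonzero x, ||A x|| / ||x|| <= ||A||_2, so this is a valid lower
    # bound on the spectral norm. Probing with the largest-norm column of A
    # tends to make the bound tight for SPD matrices. Assumes A is nonzero.
    col_norms = jnp.linalg.norm(A, axis=0)
    x = A[:, jnp.argmax(col_norms)]
    return jnp.linalg.norm(A @ x) / jnp.linalg.norm(x)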

Citations

This library is a translation of the work by Xi-Lin Li. If you use this optimizer, please cite the original papers:

@article{li2015preconditioned,
  title={Preconditioned Stochastic Gradient Descent},
  author={Li, Xi-Lin},
  journal={arXiv preprint arXiv:1512.04202},
  year={2015}
}

@article{li2018preconditioner,
  title={Preconditioner on Matrix Lie Group for SGD},
  author={Li, Xi-Lin},
  journal={arXiv preprint arXiv:1809.10232},
  year={2018}
}

@article{li2024stochastic,
  title={Stochastic Hessian Fittings with Lie Groups},
  author={Li, Xi-Lin},
  journal={arXiv preprint arXiv:2402.11858},
  year={2024}
}
