Differentiable spectral operators on symmetric positive definite (SPD) matrices in PyTorch

torchspd 🔥📐

Tiny toolkit of spectral operators on symmetric positive definite (SPD) matrices in PyTorch. It gives you differentiable sqrtm, invsqrtm, logm, powm, expm, a PSD projection, and a generic apply_quad for custom functions. Everything is batched, works with autograd (first order), and is written to be numerically stable near repeated or tiny eigenvalues.

pip install torchspd

Usage

import torch, torchspd as spd

# SPD input
n = 8
A = torch.randn(n, n)
A = A @ A.T + n * torch.eye(n)

R = spd.sqrtm(A)     # Matrix square root
W = spd.invsqrtm(A)  # Inverse square root
L = spd.logm(A)      # Logarithm
A_back = spd.expm(L) # Exponential
B = spd.powm(A, 0.3) # Fractional power

# Projection onto the PSD cone
X = torch.randn(n, n)
X = 0.5 * (X + X.T)
P = spd.proj_psd(X)

# Reuse eigenpairs
Y, (L, V) = spd.logm(A, return_eig=True)
S = spd.sqrtm(A, eig=(L, V))
P = spd.powm(A, 0.3, eig=(L, V))
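
For intuition, the PSD projection amounts to symmetrizing the input and clamping negative eigenvalues to zero. A minimal pure-PyTorch sketch (`proj_psd_sketch` is an illustrative helper, not torchspd's actual implementation):

```python
import torch

def proj_psd_sketch(X: torch.Tensor) -> torch.Tensor:
    """Project a symmetric matrix onto the PSD cone (Frobenius-nearest point)."""
    # Symmetric eigendecomposition: X = V diag(w) V^T.
    w, V = torch.linalg.eigh(X)
    # Clamp negative eigenvalues to zero and reassemble.
    return (V * w.clamp_min(0.0)) @ V.T

torch.manual_seed(0)
X = torch.randn(5, 5, dtype=torch.float64)
X = 0.5 * (X + X.T)          # symmetrize
P = proj_psd_sketch(X)
# All eigenvalues of P are >= 0 (up to rounding).
print(torch.linalg.eigvalsh(P).min())
```

This is the Frobenius-nearest PSD matrix to a symmetric `X`, which is presumably what `proj_psd` computes as well.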

More generally, apply_quad lets you define your own spectral function $f$ and its derivative $f'$, and computes $F(A)=V f(\Lambda) V^\top$ differentiably w.r.t. $A=V\Lambda V^\top$.

def f(x): return torch.log1p(x)
def df(x): return 1 / (1 + x)

Y = spd.apply_quad(A, f, df)

apply_quad falls back to Gauss-Legendre quadrature when eigenvalues are close. This is approximate: for important cases, prefer the explicit formulas used in sqrtm, invsqrtm... (see the derivations below).

Calculations

Let $A=V\Lambda V^\top\in\mathbb{R}^{d\times d}$ be SPD, with $\Lambda=\mathrm{diag}(\lambda_1,\ldots,\lambda_d)$ and $VV^\top=I_d$, and let $f:\mathbb{R}_+^*\rightarrow\mathbb{R}$ be continuously differentiable. We define $F(A)=Vf(\Lambda)V^\top$, where $f(\Lambda)=\mathrm{diag}(f(\lambda_1),\ldots,f(\lambda_d))$. This function is well-defined on the set of SPD matrices, as it does not depend on the choice of $V$.
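
Concretely, $F(A)$ can be formed from a symmetric eigendecomposition. A minimal pure-PyTorch sketch (`apply_spectral` is an illustrative helper, not the library's implementation):

```python
import torch

def apply_spectral(A: torch.Tensor, f) -> torch.Tensor:
    """F(A) = V f(Lambda) V^T for a symmetric matrix A."""
    lam, V = torch.linalg.eigh(A)   # A = V diag(lam) V^T
    return (V * f(lam)) @ V.T       # scale columns of V by f(lam), reassemble

torch.manual_seed(0)
n = 6
A = torch.randn(n, n, dtype=torch.float64)
A = A @ A.T + n * torch.eye(n, dtype=torch.float64)  # SPD

R = apply_spectral(A, torch.sqrt)
print(torch.allclose(R @ R, A))  # the square root squares back to A
```

With $f=\sqrt{\cdot}$, checking that the result squares back to $A$ is a convenient sanity test; composing $\exp$ after $\log$ gives another.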

By the Daleckii-Krein theorem, the Fréchet derivative of $f(X)$ at $X=A$, applied to a symmetric perturbation $H$, satisfies $$\mathrm{d}F_A(H)=V\left(G\circ (V^\top HV)\right)V^\top,$$ where $\circ$ denotes the coordinatewise (Hadamard) matrix product, and $$G_{ij} = \frac{f(\lambda_i)-f(\lambda_j)}{\lambda_i-\lambda_j} \quad (\text{with } G_{ii}=f'(\lambda_i)).$$
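
This formula is easy to sanity-check numerically: build $G$ for $f=\log$ and compare $\mathrm{d}F_A(H)$ against a central finite difference of $F$. A hedged sketch in plain PyTorch (not part of torchspd):

```python
import torch

torch.manual_seed(0)
n = 5
A = torch.randn(n, n, dtype=torch.float64)
A = A @ A.T + n * torch.eye(n, dtype=torch.float64)  # SPD
H = torch.randn(n, n, dtype=torch.float64)
H = 0.5 * (H + H.T)                                  # symmetric perturbation

# Daleckii-Krein: dF_A(H) = V (G ∘ (V^T H V)) V^T, here for f = log.
lam, V = torch.linalg.eigh(A)
num = torch.log(lam)[:, None] - torch.log(lam)[None, :]
den = lam[:, None] - lam[None, :]
# Off-diagonal: divided difference; diagonal: f'(lam_i) = 1/lam_i.
G = torch.where(den.abs() > 1e-12, num / den, 1.0 / lam[:, None])
dF = V @ (G * (V.T @ H @ V)) @ V.T

# Central finite-difference reference on F(A) = V log(Lambda) V^T.
def F(M):
    w, U = torch.linalg.eigh(M)
    return (U * torch.log(w)) @ U.T

eps = 1e-6
fd = (F(A + eps * H) - F(A - eps * H)) / (2 * eps)
print((dF - fd).abs().max())  # agreement up to finite-difference error
```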

Here are the values of $G_{ij}$ for common special cases, as used in the implementation. We write $\delta=\frac{\lambda_i-\lambda_j}{\lambda_j}$ and $\mathrm{sinhc}(x)=\frac{\sinh(x)}{x}$ (with $\mathrm{sinhc}(0)=1$).

| Function $f(x)$ | Off-diagonal terms $G_{ij},\ i\neq j$ | Diagonal terms $G_{ii}$ |
|---|---|---|
| $f(x) = \sqrt{x}$ | $\dfrac{1}{\sqrt{\lambda_i} + \sqrt{\lambda_j}}$ | $\dfrac{1}{2\sqrt{\lambda_i}}$ |
| $f(x) = 1/\sqrt{x}$ | $-\dfrac{1}{\sqrt{\lambda_i}\sqrt{\lambda_j}(\sqrt{\lambda_i}+\sqrt{\lambda_j})}$ | $-\dfrac{1}{2\lambda_i^{3/2}}$ |
| $f(x) = \log x$ | $\dfrac{1}{\lambda_j}\dfrac{\log(1+\delta)}{\delta}$ | $\dfrac{1}{\lambda_i}$ |
| $f(x) = e^{x}$ | $e^{(\lambda_i+\lambda_j)/2}\,\mathrm{sinhc}\!\left(\tfrac{\lambda_i-\lambda_j}{2}\right)$ | $e^{\lambda_i}$ |
| $f(x) = x^p,\; p\in\mathbb{R}$ | $\lambda_j^{p-1}\dfrac{(1+\delta)^{p}-1}{\delta}$ | $p\lambda_i^{p-1}$ |

We also use Taylor expansions of these formulas when $\lambda_i \approx \lambda_j$, replacing the divided differences by series in $\delta = (\lambda_i-\lambda_j)/\lambda_j$ to avoid catastrophic cancellation.
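
To see why this matters, take $f=\log$: the naive divided difference subtracts two nearly equal logarithms, while the $\log(1+\delta)/\delta$ form computed with `log1p` sidesteps the cancellation. An illustrative comparison (not library code):

```python
import torch

lam_j = torch.tensor(2.0, dtype=torch.float64)
delta = torch.tensor(1e-12, dtype=torch.float64)
lam_i = lam_j * (1 + delta)   # nearly equal eigenvalues

# Naive divided difference: subtracts two almost identical logs,
# losing most significant digits in the numerator.
naive = (torch.log(lam_i) - torch.log(lam_j)) / (lam_i - lam_j)

# Stable form: G_ij = log(1 + delta) / (lambda_j * delta), via log1p.
stable = torch.log1p(delta) / (lam_j * delta)

exact = 1.0 / lam_j  # leading-order limit f'(lambda_j) as delta -> 0
print(abs(naive - exact).item(), abs(stable - exact).item())
```

The stable form stays accurate to roughly machine precision, whereas the naive quotient can lose several digits to rounding in the numerator.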

For generic $f$ and $i\ne j$, we rely on the fact that $$G_{ij}=\int_{0}^{1}f'\left((1-t)\lambda_{j}+t\lambda_{i}\right)\text{d}t.$$ We currently approximate this integral using the Gauss-Legendre rule.
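
For illustration, this integral can be approximated with Gauss-Legendre nodes mapped from $[-1,1]$ to $[0,1]$. A NumPy sketch (`G_offdiag` and the fixed order are illustrative assumptions, not torchspd's exact scheme):

```python
import numpy as np

def G_offdiag(f_prime, lam_i, lam_j, order=8):
    """Approximate G_ij = ∫_0^1 f'((1-t) lam_j + t lam_i) dt by Gauss-Legendre."""
    # Nodes/weights on [-1, 1], mapped affinely to [0, 1].
    x, w = np.polynomial.legendre.leggauss(order)
    t = 0.5 * (x + 1.0)
    wt = 0.5 * w
    return np.sum(wt * f_prime((1 - t) * lam_j + t * lam_i))

lam_i, lam_j = 3.0, 1.2
approx = G_offdiag(lambda x: 1.0 / x, lam_i, lam_j)        # f = log, f' = 1/x
exact = (np.log(lam_i) - np.log(lam_j)) / (lam_i - lam_j)  # divided difference
print(approx, exact)
```

Since $f'$ is smooth on $[\lambda_j,\lambda_i]$ for the functions considered here, a low-order rule already converges rapidly.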

References

  • Functions of Matrices: Theory and Computation (Higham, 2008).
  • Improved Inverse Scaling and Squaring Algorithms for the Matrix Logarithm (Al-Mohy & Higham, 2012).
  • A Formula for the Fréchet Derivative of a Generalized Matrix Function (Noferini, 2016).
