
GPU-accelerated matrix exponential and related functions for Apple MLX


mlx-expm


MLX has no built-in expm. This package fills the gap.

Functions

Function            Description                   Algorithm
expm(A)             Matrix exponential            Scaling & squaring + [13/13] Pade
expm_frechet(A, E)  Frechet derivative of expm    Block-triangular (Van Loan)
logm(A)             Principal matrix logarithm    Inverse scaling & squaring
sqrtm(A)            Principal matrix square root  Denman-Beavers iteration

Installation

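From PyPI:

```shell
pip install mlx-expm
```

For development, from a clone of the source repository:

```shell
pip install -e ".[dev]"
```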

Usage

```python
import mlx.core as mx
from mlx_expm import expm, expm_frechet, logm, sqrtm

# Quantum time evolution: U = exp(-i H t)
H = mx.array([[0.0, 1.0], [1.0, 0.0]])  # Pauli X
t = 0.5
U = expm(-1j * H * t)

# Matrix logarithm (inverse of expm)
A = mx.array([[1.0, 2.0], [0.0, 3.0]])
L = logm(expm(A))  # recovers A up to float32 rounding

# Matrix square root
S = sqrtm(mx.array([[4.0, 0.0], [0.0, 9.0]]))  # [[2, 0], [0, 3]]

# Frechet derivative of expm at A in the direction E
A = mx.array([[1.0, 0.5], [0.0, 2.0]])
E = mx.array([[0.1, 0.0], [0.0, 0.1]])
expm_A, L = expm_frechet(A, E)  # exp(A) and the derivative L(A, E)
```
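The function table cites Van Loan's block-triangular construction for expm_frechet. It rests on the identity exp([[A, E], [0, A]]) = [[exp(A), L(A, E)], [0, exp(A)]], so a single exponential of a 2n x 2n matrix yields both exp(A) and the Frechet derivative L(A, E). The NumPy sketch below illustrates that identity only; it is not the package's MLX code, the helper names expm_taylor and frechet_vanloan are hypothetical, and for brevity the inner exponential is a plain scaled Taylor series rather than the Pade method described under Algorithm Details.

```python
import numpy as np


def expm_taylor(M, terms=20):
    """Simple stand-in expm: scale so ||M / 2^s||_1 <= 0.5, sum a Taylor series, square back."""
    M = np.asarray(M)
    M = M.astype(np.result_type(M.dtype, np.float64))
    norm = np.linalg.norm(M, 1)
    s = max(0, int(np.ceil(np.log2(norm / 0.5)))) if norm > 0 else 0
    X = M / 2.0**s
    term = np.eye(M.shape[0], dtype=M.dtype)
    result = term.copy()
    for k in range(1, terms + 1):
        term = term @ X / k  # X^k / k!
        result = result + term
    for _ in range(s):       # undo the scaling: exp(M) = exp(X)^(2^s)
        result = result @ result
    return result


def frechet_vanloan(A, E):
    """Return (exp(A), L(A, E)) read off the exponential of [[A, E], [0, A]]."""
    A, E = np.asarray(A), np.asarray(E)
    n = A.shape[0]
    dtype = np.result_type(A.dtype, E.dtype, np.float64)
    M = np.zeros((2 * n, 2 * n), dtype=dtype)
    M[:n, :n] = A  # top-left block
    M[:n, n:] = E  # top-right block holds the direction E
    M[n:, n:] = A  # bottom-right block
    F = expm_taylor(M)
    return F[:n, :n], F[:n, n:]


A = np.array([[1.0, 0.5], [0.0, 2.0]])
E = np.array([[0.1, 0.0], [0.0, 0.1]])
expm_A, L = frechet_vanloan(A, E)
```

Because E = 0.1 I commutes with A in the Usage example, L(A, E) reduces to E exp(A) = 0.1 exp(A) there, which makes a quick sanity check for any implementation.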

Applications

  • Quantum mechanics: Unitary evolution U = exp(-iHt)
  • Control theory: State transition matrix exp(At)
  • Differential equations: Matrix ODE solutions
  • Neural ODEs: Continuous-time dynamics
  • Lie groups: Exponential map on matrix Lie algebras
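The quantum case can be checked by hand for the Pauli-X Hamiltonian from Usage: since X^2 = I, the exponential series splits into even and odd terms, giving exp(-iXt) = cos(t) I - i sin(t) X. The NumPy sketch below computes exp(-iHt) through the eigendecomposition of a Hermitian H (a different method from the package's scaling and squaring, valid only because H is Hermitian) and compares it with that closed form; unitary_evolution is a hypothetical helper name.

```python
import numpy as np


def unitary_evolution(H, t):
    """U = exp(-i H t) for Hermitian H, via H = V diag(w) V^H."""
    w, V = np.linalg.eigh(H)          # real eigenvalues w, unitary V
    phases = np.exp(-1j * w * t)      # exp(-i w_j t)
    return (V * phases) @ V.conj().T  # V diag(phases) V^H


X = np.array([[0.0, 1.0], [1.0, 0.0]])  # Pauli X
t = 0.5
U = unitary_evolution(X, t)
closed_form = np.cos(t) * np.eye(2) - 1j * np.sin(t) * X
print(np.allclose(U, closed_form))             # True
print(np.allclose(U @ U.conj().T, np.eye(2)))  # True: U is unitary
```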

Algorithm Details

expm — Scaling and Squaring with [13/13] Pade

This follows the scaling-and-squaring method used by scipy.linalg.expm (Higham 2005; Al-Mohy & Higham 2009), with the [13/13] Pade approximant:

  1. Scaling: Choose the smallest integer s >= 0 such that ||A / 2^s||_1 <= theta_13 ≈ 5.37
  2. Pade: Evaluate the [13/13] rational approximant R_{13}(A/2^s)
  3. Squaring: Recover exp(A) ≈ R_{13}(A/2^s)^{2^s} via s repeated squarings

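Higham's evaluation scheme for R_{13} needs 6 matrix multiplications and 1 linear solve, plus s further multiplications for the squaring phase; all of it runs on the MLX GPU.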
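The three steps can be sketched in float64 NumPy as follows. This illustrates Higham's scaling-and-squaring scheme; it is not the package's MLX source, the names expm_pade13 and THETA_13 are hypothetical, and the Pade coefficients are computed from the closed form b_j = (26 - j)! / (j! (13 - j)!) instead of being hard-coded.

```python
import math

import numpy as np

# Higham (2005) threshold theta_13 for double precision
THETA_13 = 5.371920351148152

# [13/13] Pade coefficients, b_j = (26 - j)! / (j! (13 - j)!) for j = 0..13
B = [math.factorial(26 - j) / (math.factorial(j) * math.factorial(13 - j))
     for j in range(14)]


def expm_pade13(A):
    """exp(A) by scaling and squaring with the [13/13] Pade approximant."""
    A = np.asarray(A)
    A = A.astype(np.result_type(A.dtype, np.float64))
    ident = np.eye(A.shape[0], dtype=A.dtype)

    # 1. Scaling: smallest s >= 0 with ||A / 2^s||_1 <= theta_13
    norm1 = np.linalg.norm(A, 1)
    s = max(0, math.ceil(math.log2(norm1 / THETA_13))) if norm1 > 0 else 0
    A = A / 2.0**s

    # 2. Pade: 3 products for the even powers ...
    A2 = A @ A
    A4 = A2 @ A2
    A6 = A4 @ A2
    # ... 2 for the odd part U and 1 for the even part V (6 in total)
    U = A @ (A6 @ (B[13] * A6 + B[11] * A4 + B[9] * A2)
             + B[7] * A6 + B[5] * A4 + B[3] * A2 + B[1] * ident)
    V = (A6 @ (B[12] * A6 + B[10] * A4 + B[8] * A2)
         + B[6] * A6 + B[4] * A4 + B[2] * A2 + B[0] * ident)
    # R_13 = (V - U)^{-1} (V + U): one linear solve, no explicit inverse
    R = np.linalg.solve(V - U, V + U)

    # 3. Squaring: exp(A) ≈ R^(2^s)
    for _ in range(s):
        R = R @ R
    return R


theta = 20.0
A = np.array([[0.0, -theta], [theta, 0.0]])
R = expm_pade13(A)  # should match the rotation matrix by angle theta
```

The rotation generator is a convenient check because its exponential is [[cos theta, -sin theta], [sin theta, cos theta]]; with theta = 20, ||A||_1 = 20 > theta_13, so the sketch picks s = 2 and exercises the squaring phase.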

logm — Inverse Scaling and Squaring

  1. Take repeated square roots X <- X^{1/2} (via Denman-Beavers), s times in total, until ||X - I|| is small
  2. Evaluate log(X) = log(I + E), with E = X - I, via a degree-8 Taylor series
  3. Scale back: log(A) = 2^s log(X)
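A NumPy sketch of these three steps, under stated assumptions: the stopping threshold ||X - I||_1 <= 0.01, the square-root tolerance, and the helper names _sqrt_db and logm_iss are hypothetical choices for illustration, not values taken from the package. The threshold is kept small so that the truncation error of the degree-8 series, roughly ||E||^9 / 9, stays negligible even after multiplying by 2^s.

```python
import numpy as np


def _sqrt_db(X, tol=1e-12, max_iter=50):
    """Principal square root of X by the Denman-Beavers iteration."""
    Y, Z = X, np.eye(X.shape[0], dtype=X.dtype)
    for _ in range(max_iter):
        Y_next = 0.5 * (Y + np.linalg.inv(Z))
        Z = 0.5 * (Z + np.linalg.inv(Y))  # uses the previous Y
        done = np.linalg.norm(Y_next - Y, 1) <= tol * np.linalg.norm(Y_next, 1)
        Y = Y_next
        if done:
            break
    return Y


def logm_iss(A, threshold=0.01, max_roots=60):
    """Principal log(A) by inverse scaling and squaring with a Taylor series."""
    A = np.asarray(A)
    A = A.astype(np.result_type(A.dtype, np.float64))
    ident = np.eye(A.shape[0], dtype=A.dtype)

    # 1. Square roots: X = A^(1/2^s) until ||X - I||_1 is small
    X, s = A, 0
    while np.linalg.norm(X - ident, 1) > threshold:
        if s == max_roots:
            raise RuntimeError("square roots did not approach the identity")
        X = _sqrt_db(X)
        s += 1

    # 2. log(I + E) ≈ E - E^2/2 + E^3/3 - ... - E^8/8, with E = X - I
    E = X - ident
    log_X = np.zeros_like(E)
    power = ident
    for k in range(1, 9):
        power = power @ E
        log_X = log_X + ((-1) ** (k + 1) / k) * power

    # 3. Scale back: log(A) = 2^s log(X)
    return 2.0**s * log_X
```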

sqrtm — Denman-Beavers Iteration

Starting from Y_0 = A and Z_0 = I, the coupled iteration converges quadratically:

  • Y_{k+1} = (Y_k + Z_k^{-1}) / 2
  • Z_{k+1} = (Z_k + Y_k^{-1}) / 2

Y_k converges to A^{1/2} and Z_k to A^{-1/2}, provided A has no eigenvalues on the closed negative real axis.
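A minimal NumPy sketch of the iteration, assuming a relative-change stopping rule and an iteration cap (both hypothetical choices; sqrtm_db is not the package's API). Both updates use the previous pair (Y_k, Z_k), so Z is updated before Y is overwritten.

```python
import numpy as np


def sqrtm_db(A, tol=1e-12, max_iter=50):
    """Return (Y, Z) with Y ≈ A^(1/2) and Z ≈ A^(-1/2) via Denman-Beavers."""
    A = np.asarray(A)
    A = A.astype(np.result_type(A.dtype, np.float64))
    Y = A.copy()                               # Y_0 = A
    Z = np.eye(A.shape[0], dtype=A.dtype)      # Z_0 = I
    for _ in range(max_iter):
        Y_next = 0.5 * (Y + np.linalg.inv(Z))  # Y_{k+1} = (Y_k + Z_k^{-1}) / 2
        Z = 0.5 * (Z + np.linalg.inv(Y))       # Z_{k+1} = (Z_k + Y_k^{-1}) / 2
        converged = np.linalg.norm(Y_next - Y, 1) <= tol * np.linalg.norm(Y_next, 1)
        Y = Y_next
        if converged:
            break
    return Y, Z


Y, Z = sqrtm_db([[4.0, 0.0], [0.0, 9.0]])
# Y ≈ [[2, 0], [0, 3]], Z ≈ [[1/2, 0], [0, 1/3]]
```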

Benchmark

```shell
python benchmark.py
python benchmark.py --sizes 32 64 128 256 512 --complex
```

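Measured on Apple M1 Max (MLX 0.31.1, Python 3.11) against scipy.linalg.expm. Speedup is SciPy time divided by mlx-expm time, so values above 1x mean mlx-expm is faster; "-" means no complex result is reported:

n       real speedup   complex speedup
32      0.03x          0.01x
128     0.15x          0.26x
256     0.53x          1.96x
512     0.54x          1.33x
1024    2.08x          -
2048    1.87x          -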

The crossover is around n=1024 for real input and n=256 for complex input; below that, SciPy's CPU LAPACK path wins because GPU dispatch overhead dominates. The maximum absolute error stays at the float32 noise floor (~1e-6 * n). See benchmark_results.md for full numbers and discussion.

Tests

```shell
pytest tests/ -v
```

References

  • Higham, Functions of Matrices, SIAM, 2008.
  • Al-Mohy & Higham, "A New Scaling and Squaring Algorithm for the Matrix Exponential", SIAM J. Matrix Anal. Appl. 31(3), 2009.
  • Al-Mohy & Higham, "Computing the Frechet Derivative of the Matrix Exponential", SIAM J. Matrix Anal. Appl. 30(4), 2009.

License

MIT — Sheng-Kai Huang



Download files


Source Distribution

mlx_expm-0.1.1.tar.gz (11.3 kB)


Built Distribution


mlx_expm-0.1.1-py3-none-any.whl (9.5 kB)


File details

Details for the file mlx_expm-0.1.1.tar.gz.

File metadata

  • Download URL: mlx_expm-0.1.1.tar.gz
  • Size: 11.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for mlx_expm-0.1.1.tar.gz
Algorithm Hash digest
SHA256 72d7a9244c8d0f05517a5193517f99209ea13dfaeb716b5dc9b015be1768e01b
MD5 dec141e1f729a815f774cdb947830330
BLAKE2b-256 4ec0f68631e4a2d81688e3de425646b909f9ef4d4ba8679e0f4f620ef8a36000


File details

Details for the file mlx_expm-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: mlx_expm-0.1.1-py3-none-any.whl
  • Size: 9.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for mlx_expm-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 45b21675d9a0a0e28de92183bea24b2af3c12970e1997fc1b03afa14dbd46047
MD5 b270cc738ba7ed5a59d3b74e77eb6c39
BLAKE2b-256 07d8b5452a4f3d9b4edc529ad61187b7e7193208c457a16799f0fbc3203a6315

