GPU-accelerated Fisher Information Matrix computation on Apple Silicon via MLX

mlx-fisher

Author: Sheng-Kai Huang (akai@fawstudio.com)

Features

  • Fisher Information Matrix from log-likelihood, model predictions, or samples
  • CMB C_l Fisher matrix for cosmological parameter estimation
  • KL divergence (classical) and quantum relative entropy D(rho||sigma)
  • Natural gradient descent optimizer with online Fisher estimation
  • All matrix operations (eigendecomposition, matrix multiply, log) on Apple GPU

Installation

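pip install mlx-fisher

To install from a source checkout in editable mode instead, run this from the repository root:

pip install -e .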

Requires Python 3.10+ and Apple Silicon (M1/M2/M3/M4).

Quick Start

import mlx.core as mx
from mlx_fisher import FisherMatrix, kl_divergence, quantum_relative_entropy

# --- Gaussian Fisher matrix from a model ---
x = mx.linspace(-3.0, 3.0, 1000)

def model(theta):
    return theta[0] * x**2 + theta[1] * x + theta[2]

theta0 = mx.array([1.0, -0.5, 2.0])
sigma = mx.ones((1000,)) * 0.5

F = FisherMatrix.from_model(model, theta0, sigma)
print(F.marginal_errors())   # 1-sigma errors on each parameter

# --- KL divergence ---
p = mx.array([0.4, 0.3, 0.2, 0.1])
q = mx.array([0.25, 0.25, 0.25, 0.25])
print(kl_divergence(p, q))   # D_KL(p || q)

# --- Quantum relative entropy ---
d = 4
rho = mx.zeros((d, d)); rho[0, 0] = 1.0       # pure state |0><0|
sigma_dm = mx.eye(d) / d                        # maximally mixed
print(quantum_relative_entropy(rho, sigma_dm))   # = ln(d)
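As a cross-check on the quick-start Fisher matrix, here is a plain NumPy sketch of the standard Gaussian-noise formula F = J^T diag(1/sigma^2) J, where J is the Jacobian of the model with respect to theta. It assumes, without confirmation from this README, that `FisherMatrix.from_model` computes this same quantity. The Jacobian is written out analytically because the quadratic model is linear in theta.

```python
import numpy as np

# Same setup as the Quick Start example, in plain NumPy.
x = np.linspace(-3.0, 3.0, 1000)
sigma = np.full(1000, 0.5)

# Jacobian of model(theta) = theta[0]*x**2 + theta[1]*x + theta[2] w.r.t. theta.
# The model is linear in theta, so J is the same at every theta (including theta0).
J = np.stack([x**2, x, np.ones_like(x)], axis=1)  # shape (1000, 3)

# Gaussian Fisher matrix: F = J^T diag(1/sigma^2) J
F = J.T @ (J / sigma[:, None] ** 2)

# Covariance estimate F^{-1}; marginal 1-sigma errors are sqrt of its diagonal.
cov = np.linalg.inv(F)
marginal_errors = np.sqrt(np.diag(cov))
print(marginal_errors)
```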

CMB Fisher Matrix

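import mlx.core as mx
from mlx_fisher import fisher_matrix_cl

def cl_fn(theta):
    """Map cosmological parameters to the C_l power spectrum."""
    # Your Boltzmann solver here (e.g., CLASS wrapper)
    ...

# Fiducial values. They appear to follow a Planck-like ordering
# (omega_b h^2, omega_c h^2, tau, n_s, ln(10^10 A_s), H0); the order must match cl_fn.
theta_fid = mx.array([0.022, 0.12, 0.06, 0.96, 3.04, 67.4])
F = fisher_matrix_cl(cl_fn, theta_fid, f_sky=0.7, l_min=2, l_max=2500)
print(F.marginal_errors())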

Natural Gradient Descent

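import mlx.core as mx
from mlx_fisher import NaturalGradientOptimizer

opt = NaturalGradientOptimizer(lr=1e-2, damping=1e-4)

# compute_gradient and compute_fisher are user-supplied placeholders:
# compute_gradient(theta) returns the loss gradient at theta, and
# compute_fisher is passed to the optimizer as its Fisher estimator.
theta = mx.array([1.0, -0.5, 2.0])   # initial parameters (example values)
for step in range(100):
    grad = compute_gradient(theta)
    theta = opt.step(theta, grad, fisher_estimator=compute_fisher)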

API Reference

FisherMatrix

  • FisherMatrix.from_model(model_fn, theta, sigma) -- Gaussian Fisher matrix
  • FisherMatrix.from_loglikelihood(log_lik, theta) -- from log-likelihood function
  • FisherMatrix.from_samples(log_prob, theta, samples) -- empirical Fisher
  • .inverse(reg=0.0) -- covariance matrix (regularised inversion)
  • .marginal_errors(reg=0.0) -- 1-sigma marginal errors
  • .eigenvalues() -- eigenvalue spectrum
  • .condition_number() -- matrix condition number
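The empirical Fisher behind `from_samples` is commonly defined as the average outer product of per-sample scores, F ≈ (1/N) Σ_n s_n s_n^T with s_n = ∇_theta log p(x_n | theta). The NumPy sketch below uses that definition (an assumption; the package's normalisation may differ) for a Gaussian parameterised by (mu, log sigma), whose exact Fisher matrix diag(1/sigma^2, 2) gives something to compare against. The helper names are hypothetical.

```python
import numpy as np

def gaussian_score(x, mu, log_sigma):
    """Per-sample gradient of log N(x | mu, sigma^2) w.r.t. (mu, log_sigma)."""
    sigma2 = np.exp(2.0 * log_sigma)
    d_mu = (x - mu) / sigma2
    d_log_sigma = -1.0 + (x - mu) ** 2 / sigma2
    return np.stack([d_mu, d_log_sigma], axis=1)  # shape (N, 2)

def empirical_fisher(scores):
    """Average outer product of per-sample scores: (1/N) * S^T S."""
    return scores.T @ scores / scores.shape[0]

rng = np.random.default_rng(0)
mu, sigma = 1.0, 0.5
samples = rng.normal(mu, sigma, size=200_000)

F_emp = empirical_fisher(gaussian_score(samples, mu, np.log(sigma)))
F_exact = np.diag([1.0 / sigma**2, 2.0])  # closed form for this parameterisation
print(F_emp)
print(F_exact)
```

With enough samples drawn at the true parameters, the empirical estimate approaches the exact matrix; away from the true parameters, or for misspecified models, the two generally differ.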

fisher_matrix_cl(cl_fn, theta, f_sky, noise_cl, l_min, l_max)

CMB power spectrum Fisher matrix with cosmic variance.
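For a single power spectrum, the usual Fisher matrix including cosmic variance is F_ij = Σ_l (2l+1) f_sky / 2 · (∂C_l/∂theta_i)(∂C_l/∂theta_j) / (C_l + N_l)^2, with N_l the noise spectrum. The NumPy sketch below implements that formula with central finite differences. It assumes a single spectrum (no joint TT/TE/EE covariance), treats `noise_cl` as N_l, and uses a hypothetical toy spectrum in place of a Boltzmann solver; it is not the package's implementation.

```python
import numpy as np

def cl_fisher(cl_fn, theta, f_sky, noise_cl, l_min, l_max, rel_step=1e-3):
    """Single-spectrum C_l Fisher matrix with cosmic variance plus noise.

    cl_fn(theta) must return C_l for l = l_min..l_max (inclusive).
    Derivatives use central finite differences with step rel_step * |theta_i|.
    """
    theta = np.asarray(theta, dtype=float)
    ells = np.arange(l_min, l_max + 1)
    cl = cl_fn(theta)
    derivs = []
    for i in range(theta.size):
        h = rel_step * max(abs(theta[i]), 1e-8)
        up = theta.copy()
        down = theta.copy()
        up[i] += h
        down[i] -= h
        derivs.append((cl_fn(up) - cl_fn(down)) / (2.0 * h))
    D = np.array(derivs)  # shape (n_params, n_ell)
    weight = (2 * ells + 1) * f_sky / 2.0 / (cl + noise_cl) ** 2
    return (D * weight) @ D.T

# Hypothetical toy spectrum: C_l = A * (l / 100)^(n - 1) / (l (l + 1)), l = 2..2000
def toy_cl(theta, ells=np.arange(2, 2001)):
    A, n = theta
    return A * (ells / 100.0) ** (n - 1.0) / (ells * (ells + 1.0))

F = cl_fisher(toy_cl, [1.0, 0.96], f_sky=0.7, noise_cl=0.0, l_min=2, l_max=2000)
print(np.sqrt(np.diag(np.linalg.inv(F))))  # marginal 1-sigma errors
```

Because C_l is linear in the amplitude A here, the noiseless F_AA reduces to f_sky / 2 · Σ_l (2l+1), which is a convenient sanity check.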

kl_divergence(p, q)

Classical KL divergence D_KL(p || q).
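The classical definition is D_KL(p || q) = Σ_i p_i ln(p_i / q_i). A minimal NumPy sketch, assuming natural logarithms (nats) and the convention 0 · ln 0 = 0, reproduces the quick-start example; the function name `kl_divergence_np` is hypothetical.

```python
import numpy as np

def kl_divergence_np(p, q):
    """D_KL(p || q) = sum_i p_i * ln(p_i / q_i), in nats.

    Terms with p_i == 0 contribute 0. Assumes q_i > 0 wherever p_i > 0;
    otherwise the true divergence is infinite.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

p = np.array([0.4, 0.3, 0.2, 0.1])
q = np.full(4, 0.25)
print(kl_divergence_np(p, q))  # ≈ 0.1064
```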

quantum_relative_entropy(rho, sigma)

Quantum relative entropy D(rho || sigma) = Tr[rho(ln rho - ln sigma)].
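Because rho and sigma are Hermitian, both logarithms can be taken through an eigendecomposition. This NumPy sketch (hypothetical helper names, natural logarithms assumed) computes Tr[rho ln rho] from the eigenvalues of rho with 0 · ln 0 = 0, and ln sigma as V diag(ln w) V^†. Eigenvalues of sigma are clipped at a small epsilon, so a sigma that is singular on the support of rho yields a large finite value instead of the mathematically infinite one.

```python
import numpy as np

def matrix_log_psd(a, eps=1e-12):
    """Matrix logarithm of a Hermitian PSD matrix via eigendecomposition.

    Eigenvalues below eps are clipped to eps to keep the log finite.
    """
    w, v = np.linalg.eigh(a)
    w = np.clip(w, eps, None)
    return (v * np.log(w)) @ v.conj().T  # V diag(ln w) V^dagger

def quantum_relative_entropy_np(rho, sigma):
    """D(rho || sigma) = Tr[rho (ln rho - ln sigma)], in nats."""
    # Tr[rho ln rho] from rho's eigenvalues, dropping zeros (0 ln 0 = 0).
    w = np.linalg.eigvalsh(rho)
    w = w[w > 1e-12]
    tr_rho_log_rho = np.sum(w * np.log(w))
    tr_rho_log_sigma = np.trace(rho @ matrix_log_psd(sigma)).real
    return float(tr_rho_log_rho - tr_rho_log_sigma)

d = 4
rho = np.zeros((d, d))
rho[0, 0] = 1.0                 # pure state |0><0|
sigma_dm = np.eye(d) / d        # maximally mixed
print(quantum_relative_entropy_np(rho, sigma_dm), np.log(d))
```

For commuting (for example diagonal) density matrices the result reduces to the classical KL divergence of their eigenvalue distributions.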

NaturalGradientOptimizer(lr, damping, fisher_update_interval, ema_decay)

Natural gradient descent: theta_new = theta - lr * F^{-1} @ grad.
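The update rule can be sketched in a few lines of NumPy. How `damping` enters is not stated here; the sketch assumes the common Tikhonov form (F + damping · I)^{-1} and solves the linear system rather than forming the inverse. The toy loss is hypothetical: a quadratic whose Hessian stands in for F, with lr=0.5 chosen so the example converges quickly.

```python
import numpy as np

def natural_gradient_step(theta, grad, fisher, lr=1e-2, damping=1e-4):
    """theta_new = theta - lr * (F + damping * I)^{-1} @ grad (damping form assumed)."""
    F_damped = fisher + damping * np.eye(fisher.shape[0])
    # Solve (F + damping*I) x = grad instead of inverting the matrix.
    return theta - lr * np.linalg.solve(F_damped, grad)

# Hypothetical ill-conditioned quadratic loss 0.5 * (theta - t)^T A (theta - t).
# For this loss the Hessian A plays the role of the Fisher matrix.
A = np.diag([1000.0, 1.0])
target = np.array([1.0, -2.0])
theta = np.zeros(2)

for _ in range(50):
    grad = A @ (theta - target)
    theta = natural_gradient_step(theta, grad, A, lr=0.5, damping=1e-6)

print(theta)
```

Because the step is preconditioned by F, both coordinates shrink by roughly the same factor each iteration despite the 1000:1 curvature ratio; plain gradient descent would need a learning rate below 2/1000 to stay stable and would then crawl along the flat direction.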

Benchmarks (M1 Max)

| Operation | Scale | MLX (ms) | NumPy (ms) | Speedup |
|---|---|---|---|---|
| KL divergence | 1M bins | 0.52 | 5.22 | 10x |
| KL divergence | 10M bins | 2.23 | 51.40 | 23x |
| Eigendecomposition | 512x512 | 17.25 | 58.07 | 3.4x |
| Matrix multiply (Fisher) | 32768x512 | 4.33 | 169.96 | 39x |

See benchmark_results.md for full results.

License

MIT

Download files


Source Distribution

mlx_fisher-0.1.0.tar.gz (16.3 kB)


Built Distribution


mlx_fisher-0.1.0-py3-none-any.whl (13.0 kB)


File details

Details for the file mlx_fisher-0.1.0.tar.gz.

File metadata

  • Download URL: mlx_fisher-0.1.0.tar.gz
  • Upload date:
  • Size: 16.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for mlx_fisher-0.1.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 707ca46fbd3f1c3567fcf97dbe9d479b620c0d4b41f8d936cc260fc9be969823 |
| MD5 | 7bd784b27a6ff9333f456f22561a9004 |
| BLAKE2b-256 | b51972d475d0e48c962d97219903f817d497eb607281c9be0cffbc5a957bb98a |


File details

Details for the file mlx_fisher-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: mlx_fisher-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 13.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for mlx_fisher-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | e8dad6d6da66d7ab5257a81d41275f1474a7aff7bb55d1f6a25828a71aef8338 |
| MD5 | 4ed399765c622b665292fb480ad60708 |
| BLAKE2b-256 | 6ca1d295f3d64acd4377caef369b36a2bf8a1d4a507ee2034c991dc9d70a7a6c |
