
GPU-accelerated hypergeometric functions (2F1, 1F1, 0F1) for Apple MLX


mlx-hyp2f1

GPU-accelerated hypergeometric functions for Apple MLX.

JAX provides hyp2f1, but PyTorch, CuPy, and MLX do not. This package fills that gap for MLX with vectorized implementations that run on the Apple Silicon GPU.

Features

  • hyp2f1(a, b, c, z) -- Gauss hypergeometric function 2F1
  • hyp1f1(a, b, z) -- Confluent (Kummer) hypergeometric function 1F1
  • hyp0f1(b, z) -- Confluent hypergeometric limit function 0F1 (closely related to Bessel functions)
  • Fully vectorized: compute over arrays of parameters simultaneously on GPU
  • Fused Metal kernel for the Taylor series: 1 GPU dispatch instead of ~200 ops, ~5x faster than the pure-MLX op path on small/medium arrays
  • Pochhammer symbols evaluated through log-gamma for numerical stability
  • Analytic continuation of 2F1 to |z| >= 1 using linear transformation formulas
  • Pure MLX -- no SciPy runtime dependency
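Here is a minimal float64 sketch, in plain Python, of the quantity the series path computes. It is the 2F1 Taylor series with each term built from log-gamma Pochhammer symbols, log (x)_n = lgamma(x + n) - lgamma(x). It is not the package's Metal kernel. The helper names, the 500-term cap, the stopping tolerance, and the restriction to a, b, c > 0 with |z| < 1 are all assumptions made for this illustration.

```python
import math


def log_pochhammer(x: float, n: int) -> float:
    """log of the rising factorial (x)_n = Gamma(x + n) / Gamma(x); assumes x > 0."""
    return math.lgamma(x + n) - math.lgamma(x)


def hyp2f1_series(a: float, b: float, c: float, z: float,
                  max_terms: int = 500, tol: float = 1e-15) -> float:
    """Truncated Taylor series of 2F1(a, b; c; z) for |z| < 1 and a, b, c > 0."""
    if not abs(z) < 1.0:
        raise ValueError("the Taylor series only converges for |z| < 1")
    if z == 0.0:
        return 1.0
    log_abs_z = math.log(abs(z))
    total = 0.0
    for n in range(max_terms):
        # log|t_n| = log (a)_n + log (b)_n - log (c)_n - log n! + n log|z|
        log_term = (log_pochhammer(a, n) + log_pochhammer(b, n)
                    - log_pochhammer(c, n) - math.lgamma(n + 1)
                    + n * log_abs_z)
        term = math.exp(log_term)
        if z < 0.0 and n % 2 == 1:
            term = -term  # z**n is negative for odd n when z < 0
        total += term
        if abs(term) < tol * abs(total):  # stop once terms are negligible
            break
    return total
```

For example, hyp2f1_series(0.5, 1.0, 1.5, 0.3) reproduces the closed form artanh(sqrt(0.3)) / sqrt(0.3), which is the scalar call from the Usage section. Working in logs keeps Gamma(a + n) from overflowing as n grows. A term-ratio recurrence is the cheaper alternative when the parameters are moderate.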

Installation

pip install mlx-hyp2f1

Or from source:

git clone https://github.com/akaiHuang/mlx-hyp2f1.git
cd mlx-hyp2f1
pip install -e ".[dev]"

Usage

import mlx.core as mx
from mlx_hyp2f1 import hyp2f1, hyp1f1, hyp0f1

# Scalar
result = hyp2f1(0.5, 1.0, 1.5, 0.3)

# Vectorized over z
z = mx.linspace(0.0, 0.95, 100)
result = hyp2f1(0.5, 1.0, 1.5, z)

# Vectorized over all parameters
a = mx.array([0.5, 1.0, 2.0])
b = mx.array([1.0, 1.5, 0.5])
c = mx.array([1.5, 2.0, 3.0])
z = mx.array([0.3, 0.5, 0.8])
result = hyp2f1(a, b, c, z)

# Confluent hypergeometric (Kummer)
result = hyp1f1(1.0, 2.0, mx.linspace(-5.0, 5.0, 100))

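# 0F1, closely related to Bessel functions
result = hyp0f1(1.0, mx.linspace(-10.0, 10.0, 100))

# MLX evaluates lazily; mx.eval() (or printing) forces the computation
mx.eval(result)
print(result)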
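Closed-form special cases give a cheap sanity check for outputs like these. The Kummer example satisfies 1F1(1; 2; z) = (e^z - 1)/z. The 0F1 path can be checked with 0F1(; 1/2; z^2/4) = cosh z, which becomes cos z for a negative argument. The sketch below is a float64, standard-library-only reference that sums each series with a term recurrence. The function names and tolerances are hypothetical, and this is not the package's implementation. With MLX arrays, the analogous check would compare hyp1f1(1.0, 2.0, z) against (mx.exp(z) - 1) / z away from z = 0, at a float32-appropriate tolerance.

```python
import math


def hyp1f1_ref(a, b, z, max_terms=1000, tol=1e-16):
    """Reference 1F1(a; b; z): sum terms t_{n+1} = t_n * (a + n) z / ((b + n)(n + 1))."""
    term = total = 1.0
    for n in range(max_terms):
        term *= (a + n) * z / ((b + n) * (n + 1))
        total += term
        if abs(term) <= tol * abs(total):  # stop once terms are negligible
            break
    return total


def hyp0f1_ref(b, z, max_terms=1000, tol=1e-16):
    """Reference 0F1(; b; z): sum terms t_{n+1} = t_n * z / ((b + n)(n + 1))."""
    term = total = 1.0
    for n in range(max_terms):
        term *= z / ((b + n) * (n + 1))
        total += term
        if abs(term) <= tol * abs(total):
            break
    return total


def rel_err(x, ref):
    return abs(x - ref) / abs(ref)


# Identity checks at a few points inside the ranges used in the examples.
for z in (-5.0, -1.5, 0.5, 5.0):
    assert rel_err(hyp1f1_ref(1.0, 2.0, z), math.expm1(z) / z) < 1e-12
for x in (0.5, 3.0):
    assert rel_err(hyp0f1_ref(0.5, x * x / 4.0), math.cosh(x)) < 1e-12
    assert rel_err(hyp0f1_ref(0.5, -x * x / 4.0), math.cos(x)) < 1e-12
```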

Cosmological growth factor D(a)

import mlx.core as mx
from mlx_hyp2f1 import hyp2f1

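def growth_factor(a, omega_m=0.3):
    """Linear growth factor D(a) for flat LCDM (unnormalized; D(a) -> a as a -> 0)."""
    omega_l = 1.0 - omega_m  # flatness: Omega_Lambda = 1 - Omega_m
    # x is negative and can exceed 1 in magnitude (x = -7/3 at a = 1 for
    # omega_m = 0.3), so this relies on the analytic continuation of 2F1.
    x = -omega_l / omega_m * a**3
    return a * hyp2f1(1.0/3.0, 1.0, 11.0/6.0, x)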
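At the present epoch the argument lies well outside the unit disc (a = 1 and omega_m = 0.3 give x = -7/3), so this call depends on analytic continuation. The README does not say which linear transformation the package applies here. One standard choice for real z <= 0 is the Pfaff transformation, 2F1(a, b; c; z) = (1 - z)^(-a) 2F1(a, c - b; c; z/(z - 1)), which maps every z <= 0 into 0 <= z/(z - 1) < 1. Below is a float64 sketch in plain Python; the helper names are hypothetical.

```python
import math


def hyp2f1_taylor(a, b, c, z, max_terms=2000, tol=1e-16):
    """Taylor series of 2F1(a, b; c; z) via the term ratio; valid for |z| < 1."""
    term = total = 1.0
    for k in range(max_terms):
        term *= (a + k) * (b + k) * z / ((c + k) * (k + 1))
        total += term
        if abs(term) <= tol * abs(total):  # stop once terms are negligible
            break
    return total


def hyp2f1_negative_z(a, b, c, z):
    """2F1(a, b; c; z) for real z <= 0 via the Pfaff transformation.

    2F1(a, b; c; z) = (1 - z)**(-a) * 2F1(a, c - b; c; z / (z - 1)),
    and z / (z - 1) lies in [0, 1) whenever z <= 0.
    """
    if z > 0.0:
        raise ValueError("this sketch only handles z <= 0")
    w = z / (z - 1.0)
    return (1.0 - z) ** (-a) * hyp2f1_taylor(a, c - b, c, w)


def growth_factor_ref(a, omega_m=0.3):
    """Same unnormalized flat-LCDM formula as growth_factor, in float64."""
    x = -(1.0 - omega_m) / omega_m * a**3
    return a * hyp2f1_negative_z(1.0 / 3.0, 1.0, 11.0 / 6.0, x)
```

Dividing by growth_factor_ref(1.0) gives the common normalization D(1) = 1. The transformed series converges slowly as z goes to minus infinity, because z/(z - 1) approaches 1 there. In that regime a 1/z-type transformation would normally take over; this sketch omits that branch.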

Benchmark

python -m mlx_hyp2f1.benchmark

Measured on an Apple M1 Max with the fused Metal kernel (MLX 0.31.1 in float32, compared against SciPy 1.16.2 in float64 as the reference):

N             MLX (ms)  SciPy (ms)   Speedup   Max rel err
100               8.10        0.06     0.01x      1.92e-06
1,000             7.50        0.20     0.03x      2.17e-06
10,000            7.23        1.73     0.24x      2.57e-06
100,000           7.51       17.05     2.27x      2.67e-06
1,000,000        22.56      167.65     7.43x      2.95e-06
5,000,000       133.17      838.77     6.30x      3.10e-06
10,000,000      289.64    1,712.08     5.91x      3.10e-06

Speedup = SciPy time / MLX time; values below 1x mean SciPy is faster.

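The fused Metal kernel cuts launch overhead from ~50 ms (about 200 op dispatches) to ~7 ms (a single dispatch). MLX time stays roughly flat up to 100k points, while SciPy's time grows linearly with N. As a result, the crossover now falls between N = 10k and 100k, around 40-50k points. From 1M points upward, mlx-hyp2f1 is 5.9-7.4x faster than SciPy. Below the crossover SciPy still wins, because the fixed cost of even one GPU launch exceeds SciPy's entire CPU loop.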
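The crossover estimate follows from a simple two-parameter cost model fitted to the table. The model assumes a roughly constant MLX time at small N (the dispatch cost) and a SciPy time proportional to N. The sketch below uses only the timings from the table. The model itself is an assumption of this illustration, as is the choice to take the fixed cost from the N <= 100k rows and the per-point cost from the largest run.

```python
# Timings from the benchmark table: N -> (MLX ms, SciPy ms, reported speedup).
TABLE = {
    100: (8.10, 0.06, 0.01),
    1_000: (7.50, 0.20, 0.03),
    10_000: (7.23, 1.73, 0.24),
    100_000: (7.51, 17.05, 2.27),
    1_000_000: (22.56, 167.65, 7.43),
    5_000_000: (133.17, 838.77, 6.30),
    10_000_000: (289.64, 1712.08, 5.91),
}

# The Speedup column is SciPy time divided by MLX time, rounded to 2 decimals.
for n, (mlx_ms, scipy_ms, reported) in TABLE.items():
    assert abs(scipy_ms / mlx_ms - reported) < 0.005, n

# Model: MLX ~ fixed dispatch cost at small N; SciPy ~ linear in N.
fixed_mlx_ms = sum(TABLE[n][0] for n in (100, 1_000, 10_000, 100_000)) / 4
scipy_ms_per_point = TABLE[10_000_000][1] / 10_000_000
crossover_n = fixed_mlx_ms / scipy_ms_per_point
print(f"fixed MLX cost ~ {fixed_mlx_ms:.2f} ms, crossover ~ {crossover_n:,.0f} points")
```

Under this model the crossover lands at roughly 44k points, between the 10k and 100k rows.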

See benchmark_results.md for the full results. It includes the comparison against the pre-fusion MLX op path (~5x speedup), 1F1/0F1 timings, and accuracy details.

Testing

pip install -e ".[dev]"
pytest

Applications

  • Cosmological growth factor D(a)
  • Beta distribution CDF / PDF
  • Angular momentum coupling coefficients
  • Legendre / Jacobi / Gegenbauer polynomials
  • Scattering amplitudes in quantum mechanics
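Two of these applications reduce directly to a single 2F1 call. The regularized incomplete beta function, which is the Beta(a, b) CDF, satisfies I_x(a, b) = x^a / (a B(a, b)) 2F1(a, 1 - b; a + 1; x). Legendre polynomials are the terminating series P_n(x) = 2F1(-n, n + 1; 1; (1 - x)/2). The float64 sketch below uses a plain Taylor series in place of the package's hyp2f1 so that it runs without MLX, and the helper names are hypothetical. The symmetry switch I_x(a, b) = 1 - I_{1-x}(b, a) is a standard trick for keeping the series argument small; the README does not specify it.

```python
import math


def hyp2f1_taylor(a, b, c, z, max_terms=5000, tol=1e-16):
    """Taylor series of 2F1(a, b; c; z) via the term ratio; valid for |z| < 1.

    Terminates exactly when a or b is a non-positive integer (a polynomial).
    """
    term = total = 1.0
    for k in range(max_terms):
        term *= (a + k) * (b + k) * z / ((c + k) * (k + 1))
        total += term
        if abs(term) <= tol * abs(total):  # negligible (or exactly zero) term
            break
    return total


def beta_cdf(x, a, b):
    """CDF of Beta(a, b) at x, i.e. the regularized incomplete beta I_x(a, b)."""
    if x <= 0.0:
        return 0.0
    if x >= 1.0:
        return 1.0
    if x > a / (a + b):
        # Symmetry I_x(a, b) = 1 - I_{1-x}(b, a) keeps the series argument small.
        return 1.0 - beta_cdf(1.0 - x, b, a)
    log_beta = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    # I_x(a, b) = x**a / (a * B(a, b)) * 2F1(a, 1 - b; a + 1; x)
    prefactor = math.exp(a * math.log(x) - math.log(a) - log_beta)
    return prefactor * hyp2f1_taylor(a, 1.0 - b, a + 1.0, x)


def legendre_p(n, x):
    """Legendre polynomial P_n(x) = 2F1(-n, n + 1; 1; (1 - x) / 2), integer n >= 0."""
    return hyp2f1_taylor(-n, n + 1, 1.0, (1.0 - x) / 2.0)
```

For integer a and b, I_x(a, b) equals the binomial tail probability P(Bin(a + b - 1, x) >= a), which gives a convenient exact check. In principle the same formulas carry over to the package's hyp2f1 for array inputs.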

License

MIT -- Sheng-Kai Huang

Download files

Source Distribution

mlx_hyp2f1-0.2.0.tar.gz (13.6 kB)

Uploaded Source

Built Distribution


mlx_hyp2f1-0.2.0-py3-none-any.whl (11.3 kB)

Uploaded Python 3

File details

Details for the file mlx_hyp2f1-0.2.0.tar.gz.

File metadata

  • Download URL: mlx_hyp2f1-0.2.0.tar.gz
  • Size: 13.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for mlx_hyp2f1-0.2.0.tar.gz
Algorithm Hash digest
SHA256 001db3a4a898032cc2370f86141e6332089d4e986aff8797204db8a20661ae70
MD5 02e1bd6454570ec6fbdc4178322685c5
BLAKE2b-256 90fb6c4db62545745480c68987ce79c96ae7e31b39280f551b1c6cad89561c6c


File details

Details for the file mlx_hyp2f1-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: mlx_hyp2f1-0.2.0-py3-none-any.whl
  • Size: 11.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for mlx_hyp2f1-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 904c91d21d279fbae889f606f819b4a01e1f91c889ec34589fa0e0720a946145
MD5 31670b54bfdc95ada5a7edba110613f1
BLAKE2b-256 aca2f65ed6d36432419bd5fa5cf36f07fd407356cf18ca2eb5db87217375afb9

