GPU-accelerated hypergeometric functions (2F1, 1F1, 0F1) for Apple MLX
mlx-hyp2f1
GPU-accelerated hypergeometric functions for Apple MLX.
JAX has hyp2f1, but PyTorch, CuPy, and MLX do not. This package fills the gap for MLX, providing vectorized implementations that run on the Apple Silicon GPU.
Features
- hyp2f1(a, b, c, z) -- Gauss hypergeometric function 2F1
- hyp1f1(a, b, z) -- Confluent (Kummer) hypergeometric function 1F1
- hyp0f1(b, z) -- Bessel-related hypergeometric function 0F1
- Fully vectorized: compute over arrays of parameters simultaneously on the GPU
- Fused Metal kernel for the Taylor series: 1 GPU dispatch instead of ~200 ops, ~5x faster than the pure-MLX op path on small/medium arrays
- Numerically stable via log-gamma Pochhammer symbols
- Analytic continuation for |z| >= 1 using linear transformation formulas
- Pure MLX -- no SciPy runtime dependency
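The series evaluation and the log-gamma Pochhammer symbols listed above can be illustrated on the CPU. The following is a minimal plain-Python sketch, not the package's fused Metal kernel: it sums the truncated Gauss series for 2F1, forming each coefficient (a)_n (b)_n / ((c)_n n!) from log-gamma differences via log (x)_n = lgamma(x + n) - lgamma(x). It assumes a, b, c > 0 (so Python's lgamma, which returns log|Gamma|, equals log Gamma) and |z| < 1. The function name hyp2f1_series and the 200-term default are illustrative choices, not the package's.

```python
import math


def hyp2f1_series(a, b, c, z, n_terms=200):
    """Truncated Gauss series sum_n (a)_n (b)_n / ((c)_n n!) z^n.

    Coefficients come from log-gamma differences, which avoids
    overflowing Gamma(x + n) for large n. Sketch assumptions:
    a, b, c > 0 and |z| < 1.
    """
    if abs(z) >= 1.0:
        raise ValueError("series only converges for |z| < 1")
    lg_a, lg_b, lg_c = math.lgamma(a), math.lgamma(b), math.lgamma(c)
    total = 0.0
    for n in range(n_terms):
        # log of (a)_n (b)_n / ((c)_n n!), using log n! = lgamma(n + 1)
        log_coef = (math.lgamma(a + n) - lg_a
                    + math.lgamma(b + n) - lg_b
                    - math.lgamma(c + n) + lg_c
                    - math.lgamma(n + 1))
        total += math.exp(log_coef) * z**n
    return total


# Check against the closed form 2F1(1/2, 1; 3/2; z) = artanh(sqrt(z)) / sqrt(z).
z = 0.3
approx = hyp2f1_series(0.5, 1.0, 1.5, z)
exact = math.atanh(math.sqrt(z)) / math.sqrt(z)
print(approx, exact)
```

Working in log space keeps every intermediate value finite; updating each term by the ratio (a + n)(b + n) z / ((c + n)(n + 1)) is a cheaper alternative when overflow is not a concern.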
Installation
```bash
pip install mlx-hyp2f1
```
Or from source:
```bash
git clone https://github.com/akaiHuang/mlx-hyp2f1.git
cd mlx-hyp2f1
pip install -e ".[dev]"
```
Usage
```python
import mlx.core as mx
from mlx_hyp2f1 import hyp2f1, hyp1f1, hyp0f1

# Scalar
result = hyp2f1(0.5, 1.0, 1.5, 0.3)

# Vectorized over z
z = mx.linspace(0.0, 0.95, 100)
result = hyp2f1(0.5, 1.0, 1.5, z)

# Vectorized over all parameters
a = mx.array([0.5, 1.0, 2.0])
b = mx.array([1.0, 1.5, 0.5])
c = mx.array([1.5, 2.0, 3.0])
z = mx.array([0.3, 0.5, 0.8])
result = hyp2f1(a, b, c, z)

# Confluent hypergeometric (Kummer)
result = hyp1f1(1.0, 2.0, mx.linspace(-5.0, 5.0, 100))

# Bessel-related
result = hyp0f1(1.0, mx.linspace(-10.0, 10.0, 100))
```
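When checking the outputs of the calls above, an independent CPU reference can help. The sketch below is plain Python, assumes modest |z| so that a fixed number of terms suffices, and uses hypothetical helper names (hyp1f1_series, hyp0f1_series) that are not part of mlx-hyp2f1. It sums the 1F1 and 0F1 series through the ratio of consecutive terms and compares them with two elementary closed forms: 1F1(1; 2; z) = (e^z - 1)/z and 0F1(; 1/2; -x^2/4) = cos x.

```python
import math


def hyp1f1_series(a, b, z, n_terms=200):
    """Kummer series sum_n (a)_n / ((b)_n n!) z^n (sketch; modest |z| only)."""
    term, total = 1.0, 1.0
    for n in range(n_terms - 1):
        # ratio of consecutive terms: (a + n) z / ((b + n)(n + 1))
        term *= (a + n) * z / ((b + n) * (n + 1))
        total += term
    return total


def hyp0f1_series(b, z, n_terms=200):
    """Series sum_n z^n / ((b)_n n!) (sketch; modest |z| only)."""
    term, total = 1.0, 1.0
    for n in range(n_terms - 1):
        # ratio of consecutive terms: z / ((b + n)(n + 1))
        term *= z / ((b + n) * (n + 1))
        total += term
    return total


# 1F1(1; 2; z) = (e^z - 1) / z, with the same a, b as the hyp1f1 call above.
for z in (-5.0, 0.5, 5.0):
    print(z, hyp1f1_series(1.0, 2.0, z), math.expm1(z) / z)

# 0F1(; 1/2; -x^2 / 4) = cos(x)
x = 3.0
print(hyp0f1_series(0.5, -x * x / 4.0), math.cos(x))
```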
Cosmological growth factor D(a)
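```python
import mlx.core as mx
from mlx_hyp2f1 import hyp2f1

def growth_factor(a, omega_m=0.3):
    """Linear growth factor D(a) for flat LCDM, normalized so D(a) -> a at early times."""
    omega_l = 1.0 - omega_m
    # x < -1 once a > (omega_m / omega_l)**(1/3), about 0.754 for omega_m = 0.3,
    # so late times rely on the |z| >= 1 analytic continuation.
    x = -omega_l / omega_m * a**3
    return a * hyp2f1(1.0 / 3.0, 1.0, 11.0 / 6.0, x)

a = mx.linspace(0.01, 1.0, 100)
D = growth_factor(a)
```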
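In the growth-factor example above, x = -(omega_l / omega_m) a^3 is about -2.33 at a = 1 for omega_m = 0.3, outside the disk |z| < 1 where the Gauss series converges, so D(1) depends on the analytic continuation mentioned under Features. One classical linear transformation for real z < 0 is Pfaff's, 2F1(a, b; c; z) = (1 - z)^(-a) 2F1(a, c - b; c; z/(z - 1)), which maps every z < 0 to w = z/(z - 1) in [0, 1). The plain-Python sketch below applies it; mlx-hyp2f1 may use a different set of transformations, and the helper names (hyp2f1_series, hyp2f1_negative_z, growth_factor_ref) are hypothetical. It is checked against the identity 2F1(1, 1; 2; z) = -ln(1 - z)/z, which holds for all z < 1.

```python
import math


def hyp2f1_series(a, b, c, z, tol=1e-16, max_terms=100_000):
    """Gauss series for 2F1 via the term ratio; converges only for |z| < 1."""
    term, total = 1.0, 1.0
    for n in range(max_terms):
        # ratio of consecutive terms: (a + n)(b + n) z / ((c + n)(n + 1))
        term *= (a + n) * (b + n) * z / ((c + n) * (n + 1))
        total += term
        if abs(term) <= tol * abs(total):
            break
    return total


def hyp2f1_negative_z(a, b, c, z):
    """2F1 for real z <= 0 via Pfaff: (1 - z)^(-a) 2F1(a, c - b; c; z / (z - 1))."""
    if z > 0.0:
        raise ValueError("this sketch only handles z <= 0")
    w = z / (z - 1.0)  # lies in [0, 1) for every z <= 0
    return (1.0 - z) ** (-a) * hyp2f1_series(a, c - b, c, w)


def growth_factor_ref(a, omega_m=0.3):
    """Scalar CPU reference for D(a) = a 2F1(1/3, 1; 11/6; -(omega_l / omega_m) a^3)."""
    omega_l = 1.0 - omega_m
    x = -omega_l / omega_m * a**3
    return a * hyp2f1_negative_z(1.0 / 3.0, 1.0, 11.0 / 6.0, x)


# 2F1(1, 1; 2; z) = -ln(1 - z) / z also holds at z = -3, outside the series' disk.
z = -3.0
print(hyp2f1_negative_z(1.0, 1.0, 2.0, z), -math.log1p(-z) / z)
print(growth_factor_ref(1.0))
```

For omega_m = 0.3, growth_factor_ref(1.0) comes out near 0.78; comparing it with the MLX growth_factor at a = 1 is a quick sanity check of the float32 continuation path.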
Benchmark
```bash
python -m mlx_hyp2f1.benchmark
```
Measured on an Apple M1 Max (MLX 0.31.1, SciPy 1.16.2; MLX in float32, compared against SciPy's float64 results as the reference) with the fused Metal kernel:
| N (points) | MLX (ms) | SciPy (ms) | Speedup (SciPy time / MLX time) | Max rel. error |
|---|---|---|---|---|
| 100 | 8.10 | 0.06 | 0.01x | 1.92e-06 |
| 1,000 | 7.50 | 0.20 | 0.03x | 2.17e-06 |
| 10,000 | 7.23 | 1.73 | 0.24x | 2.57e-06 |
| 100,000 | 7.51 | 17.05 | 2.27x | 2.67e-06 |
| 1,000,000 | 22.56 | 167.65 | 7.43x | 2.95e-06 |
| 5,000,000 | 133.17 | 838.77 | 6.30x | 3.10e-06 |
| 10,000,000 | 289.64 | 1,712.08 | 5.91x | 3.10e-06 |
The fused Metal kernel cuts launch overhead from ~50 ms (about 200 op dispatches) to ~7 ms (1 dispatch). MLX first beats SciPy at the N = 100,000 measurement, and from 1M points up mlx-hyp2f1 is 5.9-7.4x faster than SciPy. Below roughly 50k points SciPy still wins, because the fixed cost of even a single GPU launch exceeds the time SciPy's CPU loop needs for the whole array.
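The crossover estimate can be reproduced with a back-of-the-envelope model built only from the benchmark table: MLX time stays roughly flat (7.2-8.1 ms) up to N = 100,000, while SciPy time grows about linearly at roughly 0.17 µs per point. The sketch below is an illustrative assumption, not the package's benchmark code; the choice of which rows count as "flat" and which as "linear" is ours. It solves for the N where the two costs meet.

```python
# Rows copied from the benchmark table: (N, MLX ms, SciPy ms).
rows = [
    (100, 8.10, 0.06),
    (1_000, 7.50, 0.20),
    (10_000, 7.23, 1.73),
    (100_000, 7.51, 17.05),
    (1_000_000, 22.56, 167.65),
    (5_000_000, 133.17, 838.77),
    (10_000_000, 289.64, 1_712.08),
]

# Assumption: up to 100k points the MLX time is essentially a fixed launch cost.
flat = [mlx_ms for n, mlx_ms, _ in rows if n <= 100_000]
mlx_fixed_ms = sum(flat) / len(flat)

# Assumption: SciPy cost is linear in N; average ms per point over the large-N rows.
large = [(n, scipy_ms) for n, _, scipy_ms in rows if n >= 100_000]
scipy_ms_per_point = sum(scipy_ms / n for n, scipy_ms in large) / len(large)

# Crossover: fixed MLX cost equals SciPy's per-point cost times N.
n_cross = mlx_fixed_ms / scipy_ms_per_point
print(f"MLX fixed cost ~{mlx_fixed_ms:.2f} ms, "
      f"SciPy ~{scipy_ms_per_point * 1e3:.3f} us/point, crossover N ~{n_cross:,.0f}")
```

The estimate lands a little under 45,000 points, consistent with the observation that SciPy wins below roughly 50k.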
See benchmark_results.md for full results, the MLX→MLX comparison table (5x speedup vs. the pre-fusion op path), 1F1/0F1 numbers, and accuracy details.
Testing
```bash
pip install -e ".[dev]"
pytest
```
Applications
- Cosmological growth factor D(a)
- Beta distribution CDF / PDF
- Angular momentum coupling coefficients
- Legendre / Jacobi / Gegenbauer polynomials
- Scattering amplitudes in quantum mechanics
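Two of these applications reduce to a single 2F1 evaluation through standard identities: the regularized incomplete beta function (the Beta(p, q) CDF), I_x(p, q) = x^p / (p B(p, q)) 2F1(p, 1 - q; p + 1; x), and the Legendre polynomial P_n(x) = 2F1(-n, n + 1; 1; (1 - x)/2), whose series terminates because -n is a non-positive integer. The plain-Python sketch below uses a small term-ratio series in place of mlx-hyp2f1 so it runs anywhere; with the package, the same (a, b, c, z) arguments would presumably go to hyp2f1. The names hyp2f1_series, beta_cdf and legendre are hypothetical.

```python
import math


def hyp2f1_series(a, b, c, z, tol=1e-16, max_terms=100_000):
    """Gauss series for 2F1 via the term ratio (|z| < 1, or a terminating series)."""
    term, total = 1.0, 1.0
    for n in range(max_terms):
        term *= (a + n) * (b + n) * z / ((c + n) * (n + 1))
        total += term
        if abs(term) <= tol * abs(total):  # also stops once a term is exactly 0
            break
    return total


def beta_cdf(x, p, q):
    """Regularized incomplete beta I_x(p, q), the CDF of Beta(p, q), for 0 <= x < 1."""
    log_beta = math.lgamma(p) + math.lgamma(q) - math.lgamma(p + q)
    return x**p / (p * math.exp(log_beta)) * hyp2f1_series(p, 1.0 - q, p + 1.0, x)


def legendre(n, x):
    """Legendre polynomial P_n(x) = 2F1(-n, n + 1; 1; (1 - x) / 2) for -1 <= x <= 1."""
    return hyp2f1_series(-float(n), n + 1.0, 1.0, (1.0 - x) / 2.0)


print(beta_cdf(0.3, 2.0, 3.0))  # Beta(2, 3) CDF, closed form 6x^2 - 8x^3 + 3x^4
print(legendre(2, 0.5))         # P_2(x) = (3x^2 - 1) / 2
```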
License
MIT -- Sheng-Kai Huang
Download files
File details
Details for the file mlx_hyp2f1-0.2.0.tar.gz.
File metadata
- Download URL: mlx_hyp2f1-0.2.0.tar.gz
- Size: 13.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 001db3a4a898032cc2370f86141e6332089d4e986aff8797204db8a20661ae70 |
| MD5 | 02e1bd6454570ec6fbdc4178322685c5 |
| BLAKE2b-256 | 90fb6c4db62545745480c68987ce79c96ae7e31b39280f551b1c6cad89561c6c |
File details
Details for the file mlx_hyp2f1-0.2.0-py3-none-any.whl.
File metadata
- Download URL: mlx_hyp2f1-0.2.0-py3-none-any.whl
- Size: 11.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 904c91d21d279fbae889f606f819b4a01e1f91c889ec34589fa0e0720a946145 |
| MD5 | 31670b54bfdc95ada5a7edba110613f1 |
| BLAKE2b-256 | aca2f65ed6d36432419bd5fa5cf36f07fd407356cf18ca2eb5db87217375afb9 |