GPU-accelerated diffusion maps with tensor core kernels

FlashDiffusion

Tiled, memory-efficient diffusion maps eigensolver — never materialises the O(N²) kernel matrix.

What this enables

Task                                    N limit (dense)   N limit (FlashDiffusion)
Diffusion map eigenvectors              ~20k (8 GB)       ~10M (memory: O(N·tile))
Carré du Champ metric tensor            ~10k              ~1M
Schrödinger bridge / Doob h-transform   ~10k              ~1M
MD trajectory slow manifold             ~5k frames        ~500k frames

The primitive

from flashdiffusion import FlashDiffusion

# K(i,j) = exp(-beta * score_mod(xi, xj))
# with Coifman-Lafon alpha-normalisation and optional Doob h-transform
out = FlashDiffusion(X, X, V, beta=1.0, alpha=0.5, h=None)

FlashDiffusion(Q, K, V, beta, alpha, h) is the single kernel underlying everything:

  • DMAP mode (Q=K=X, V=eigvec guess): symmetric diffusion matvec for Lanczos
  • AMAP mode (Q≠K, asymmetric W): directed/NESS Markov operator (DAC §3)
  • Doob mode (h≠None): h-transform of any of the above (DAC §4)
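For small N, the operator applied in DMAP mode can be written out densely, which is useful as a correctness reference. The sketch below is not the package's implementation: it assumes score_mod is the squared Euclidean distance, uses the standard Coifman-Lafon construction, and `dense_dmap_matvec` is a hypothetical name.

```python
import numpy as np

def dense_dmap_matvec(X, V, beta=1.0, alpha=0.5):
    """Dense O(N^2) reference for a symmetric diffusion matvec.

    Assumption: score_mod(xi, xj) = ||xi - xj||^2.
    """
    sq = np.sum(X * X, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0)
    K = np.exp(-beta * d2)                          # K(i,j) = exp(-beta * score)
    q = K.sum(axis=1)                               # kernel density estimate
    K_alpha = K / np.outer(q ** alpha, q ** alpha)  # alpha-normalisation
    d = K_alpha.sum(axis=1)                         # degrees of the normalised kernel
    S = K_alpha / np.sqrt(np.outer(d, d))           # symmetric conjugate of D^-1 K_alpha
    return S @ V

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 3))
V = rng.standard_normal((200, 4))
out = dense_dmap_matvec(X, V)
```

The symmetric form S shares its eigenvalues with the row-stochastic Markov matrix, which is why a symmetric eigensolver such as Lanczos can be used on it.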

Design

kernel.py        FlashDiffusion primitive — tiled O(N·tile) memory matvec
dmap.py          DiffusionMap class: fit() precomputes rscale; transform() calls kernel
eigensolver.py   Lanczos/LOBPCG with bf16→fp32→fp64 precision ramp; orthog in fp64
cdc.py           Carré du Champ operator Γ(f,g) and metric tensor M = ΣᵢGᵢGᵢᵀ
utils.py         Bandwidth selection (self-tuning, median), Nyström extension
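The idea behind keeping the orthogonalisation in fp64 while the matvec runs at lower precision can be sketched in NumPy. This is an illustration under assumptions, not the contents of eigensolver.py: `lanczos_topk` is a hypothetical name, the matvec is fixed at float32 rather than ramped, and full re-orthogonalisation is used for simplicity.

```python
import numpy as np

def lanczos_topk(matvec, n, k, n_iter=60, seed=0):
    """Top-k eigenpairs of a symmetric operator via Lanczos.

    The matvec is fed float32 input (standing in for a low-precision kernel);
    the Krylov basis is stored and re-orthogonalised in float64.
    """
    rng = np.random.default_rng(seed)
    m = min(n_iter, n)
    Q = np.zeros((n, m))
    q0 = rng.standard_normal(n)
    Q[:, 0] = q0 / np.linalg.norm(q0)
    alphas, betas = [], []
    for j in range(m):
        w = matvec(Q[:, j].astype(np.float32)).astype(np.float64)
        alphas.append(Q[:, j] @ w)
        basis = Q[:, : j + 1]
        for _ in range(2):                  # two Gram-Schmidt passes in float64
            w -= basis @ (basis.T @ w)
        b = np.linalg.norm(w)
        if j + 1 == m or b < 1e-10:         # out of iterations or invariant subspace
            break
        betas.append(b)
        Q[:, j + 1] = w / b
    steps = len(alphas)
    T = np.diag(alphas) + np.diag(betas, 1) + np.diag(betas, -1)
    evals, evecs = np.linalg.eigh(T)        # small tridiagonal problem
    top = np.argsort(evals)[::-1][:k]
    return evals[top], Q[:, :steps] @ evecs[:, top]

# Test operator with a known spectrum 1, 0.9, 0.81, ...
rng = np.random.default_rng(1)
n = 300
U, _ = np.linalg.qr(rng.standard_normal((n, n)))
lam = 0.9 ** np.arange(n)
A32 = ((U * lam) @ U.T).astype(np.float32)
evals, vecs = lanczos_topk(lambda v: A32 @ v, n, k=3)
```

Only the matvec touches the large operator, so it is the part worth running at reduced precision; the basis vectors are cheap by comparison and lose orthogonality quickly if they are not kept in double precision.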

Physics background

Transformer attention, diffusion maps, and magnetic Laplacians are three regimes of a single Markov geometry (Candanedo 2025). The EQ/NESS classification:

  • EQ (symmetric W): DMAP — reversible diffusion, detailed balance
  • NESS (asymmetric W): self-attention — directed, probability current ≠ 0
  • Doob deformation: h-transform preserves EQ class (Theorem 5.1, DAC)

The Coifman-Lafon α-normalisation is an exact Doob transform (Corollary 4.2, DAC).
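The EQ/NESS distinction can be checked numerically on a toy example. The sketch below row-normalises a weight matrix W into a Markov matrix P, finds its stationary distribution π, and measures the probability current J_ij = π_i P_ij − π_j P_ji. The helper name `probability_current` and both test matrices are illustrative assumptions, not part of the package.

```python
import numpy as np

def probability_current(W):
    """Return J_ij = pi_i P_ij - pi_j P_ji for the Markov chain P = D^-1 W."""
    P = W / W.sum(axis=1, keepdims=True)
    # Stationary distribution: left eigenvector of P for the largest eigenvalue (1)
    evals, evecs = np.linalg.eig(P.T)
    pi = np.real(evecs[:, np.argmax(np.real(evals))])
    pi = pi / pi.sum()
    flux = pi[:, None] * P
    return flux - flux.T

rng = np.random.default_rng(0)

# EQ: symmetric Gaussian kernel on random points
X = rng.standard_normal((6, 2))
D2 = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)
W_sym = np.exp(-D2)

# NESS: attention-like weights exp(q_i . k_j) with Q != K
Qm = rng.standard_normal((6, 2))
Km = rng.standard_normal((6, 2))
W_asym = np.exp(Qm @ Km.T)

J_eq = np.abs(probability_current(W_sym)).max()
J_ness = np.abs(probability_current(W_asym)).max()
```

For the symmetric kernel the current vanishes to rounding error (detailed balance), while the asymmetric weights generically leave a nonzero current.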

FlashDiffusion SM120 — RTX 6000 / Consumer Blackwell

What's in this package

flashdiffusion/csrc/flash_diffusion_sm120.cu   ← kernel
flashdiffusion/kernel_cuda_sm120.py            ← Python wrapper
setup_sm120.py                                  ← build script
sm120_notebook.py                               ← notebook cells

SM120 vs SM80 kernel differences

Feature           SM80 (A100)        SM120 (RTX 5090)
MMA instruction   mma.sync f16→f32   mma.sync f16→f32 (same)
Tile size         64×64              128×128
SMEM async        __syncthreads      cp.async pipeline
TMA               no                 yes (not used yet)
TMEM              no                 no (SM100 only)
Cluster shape     1×1×1              1×1×1 (no multicast)
WGMMA/UMMA        no                 no (SM100 only)

The inner GEMM instruction is identical to SM80. The speedup comes from:

  1. Larger 128×128 tiles → better arithmetic intensity
  2. cp.async pipeline → hides GMEM→SMEM load latency
  3. Higher RTX 5090 clock speed
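A back-of-envelope model shows why the larger tile helps. It assumes each tile loads one (tile × d) block of Q and one of K from global memory, computes a tile × tile block of scores on chip, and ignores the V loads and the cost of exp. The function name and the feature dimension d=64 are illustrative assumptions.

```python
def tile_arithmetic_intensity(tile, d, bytes_per_elem=2):
    """FLOPs per byte of global-memory traffic for one tile of pairwise scores."""
    flops = 2 * tile * tile * d                    # one multiply-add per (i, j, feature)
    bytes_loaded = 2 * tile * d * bytes_per_elem   # Q block plus K block
    return flops / bytes_loaded

ai_64 = tile_arithmetic_intensity(64, d=64)     # 32.0 FLOPs/byte at fp16
ai_128 = tile_arithmetic_intensity(128, d=64)   # 64.0 FLOPs/byte at fp16
```

In this model the intensity is tile / bytes_per_elem, so doubling the tile edge doubles the work done per byte fetched, independent of d.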

Build

# on a machine with RTX 5090
FLASHDIFFUSION_BUILD_CUDA=1 \
TORCH_CUDA_ARCH_LIST="12.0" \
python setup_sm120.py build_ext --inplace

Or build from a notebook (see sm120_notebook.py, Cell 2).

Expected performance vs SM80

N=200k  SM80 A100  precompute: 0.4s   Lanczos: ~100s
N=200k  SM120 5090 precompute: ~0.3s  Lanczos: ~60s  (estimate)

The RTX 5090 has slightly lower memory bandwidth than the A100 (about 1.8 TB/s vs 2.0 TB/s), but our scalar kernel is compute-bound, not memory-bound, so bandwidth is not the limiting factor. The cp.async pipeline and larger tiles are the main wins.

Next step: TiledMMA on SM120

The mma.sync.aligned.m16n8k16 instruction on SM120 uses the same CuTe atom as SM80: SM80_16x8x16_F32F16F16F32. Wiring in TiledMMA would replace the scalar inner loop, for an expected ~10× speedup. This will be flash_diffusion_sm120_v2.cu, to be added next.

CUDA roadmap

Current: NumPy reference (validates correctness, runs on CPU).

Next: CuTe kernel with SM90 tensor cores and a fused prepass+matvec. No online softmax is needed: the precomputed rscale eliminates the log-sum-exp recurrence that drives FA4 complexity.

prepass:  out_i = Σⱼ K(i,j) · wⱼ          # scalar reduction per row, one kernel
matvec:   out_i = rscaleᵢ · Σⱼ K(i,j) · rscaleⱼ · vⱼ  # vector, same tiling

Both kernels share the same tile loop; only the epilogue differs.
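The shared tile loop can be sketched in NumPy. This is one plausible reading of the two formulas above, not the package's code: it assumes K(i,j) = exp(-beta·‖xi − xj‖²), and it assumes rscale combines the α-normalisation and the degree normalisation as rscaleᵢ = qᵢ^(-α) / √dᵢ. The names `tiled_apply`, `fit_rscale`, and `matvec` are hypothetical.

```python
import numpy as np

def tiled_apply(X, coeff, beta=1.0, tile=64, row_scale=None):
    """Compute out_i = row_scale_i * sum_j K(i,j) * coeff_j, one tile at a time.

    Only a (tile x tile) block of K exists at any moment.
    """
    n = X.shape[0]
    out = np.zeros(n)
    sq = np.sum(X * X, axis=1)
    for i0 in range(0, n, tile):
        i1 = min(i0 + tile, n)
        acc = np.zeros(i1 - i0)
        for j0 in range(0, n, tile):                 # shared tile loop
            j1 = min(j0 + tile, n)
            d2 = sq[i0:i1, None] + sq[None, j0:j1] - 2.0 * X[i0:i1] @ X[j0:j1].T
            K_blk = np.exp(-beta * np.maximum(d2, 0.0))
            acc += K_blk @ coeff[j0:j1]
        # Epilogue: the only part that differs between prepass and matvec
        out[i0:i1] = acc if row_scale is None else row_scale[i0:i1] * acc
    return out

def fit_rscale(X, beta=1.0, alpha=0.5, tile=64):
    """Two prepass reductions producing the per-point scale (assumed form)."""
    q = tiled_apply(X, np.ones(len(X)), beta, tile)   # prepass 1: w = 1
    q_a = q ** (-alpha)
    d = q_a * tiled_apply(X, q_a, beta, tile)         # prepass 2: w = q^-alpha
    return q_a / np.sqrt(d)

def matvec(X, rscale, v, beta=1.0, tile=64):
    """out_i = rscale_i * sum_j K(i,j) * rscale_j * v_j."""
    return tiled_apply(X, rscale * v, beta, tile, row_scale=rscale)

rng = np.random.default_rng(0)
X = rng.standard_normal((300, 3))
rscale = fit_rscale(X)
v = rng.standard_normal(300)
y = matvec(X, rscale, v)
```

Because rscale is fixed after the prepasses, each matvec is a plain weighted sum over tiles with no running maximum or rescaling, which is the simplification relative to online softmax.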

Installation

pip install -e ".[dev]"

Quick start

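import numpy as np
from flashdiffusion import DiffusionMap

# Example input; the (N points, d features) layout is an assumption
X = np.random.default_rng(0).standard_normal((1000, 3))

dm = DiffusionMap(beta=1.0, alpha=0.5, n_components=8)
dm.fit(X)               # two prepass tiled reductions, O(N·tile) memory
coords = dm.transform() # Lanczos eigensolver, mixed precision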

References

  • Coifman & Lafon (2006) — diffusion maps
  • Candanedo (2025) — diffusion-attention connection, bidivergence, Doob classification
  • Dao et al. (2024) — FlashAttention-4
  • Rohrdanz, Zheng, Clementi (2011) — diffusion maps for MD
