GPU-accelerated diffusion maps with tensor core kernels

FlashDiffusion

Tiled, memory-efficient diffusion maps eigensolver — never materialises the O(N²) kernel matrix.

What this enables

Task                                   N limit (dense)   N limit (FlashDiffusion)
Diffusion map eigenvectors             ~20k (8 GB)       ~10M (memory: O(N·tile))
Carré du Champ metric tensor           ~10k              ~1M
Schrödinger bridge / Doob h-transform  ~10k              ~1M
MD trajectory slow manifold            ~5k frames        ~500k frames
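As a rough back-of-the-envelope check on the scaling in the table, the sketch below compares the raw storage of a dense N×N kernel matrix with an O(N·tile) working set. The helper names, the tile width of 128 and the element sizes are illustrative assumptions, not values taken from the package. The dense figure counts only the matrix itself, so the 8 GB budget quoted above presumably also leaves room for working copies.

```python
def dense_kernel_bytes(n, bytes_per_entry=8):
    """Bytes needed to store a full N x N kernel matrix (fp64 by default)."""
    return n * n * bytes_per_entry


def tiled_working_bytes(n, tile=128, bytes_per_entry=4):
    """Upper bound on an O(N * tile) working set (fp32 by default)."""
    return n * tile * bytes_per_entry


GB = 1e9
print(f"dense, N=20k, fp64:           {dense_kernel_bytes(20_000) / GB:.1f} GB")
print(f"dense, N=10M, fp64:           {dense_kernel_bytes(10_000_000) / GB:.0f} GB")
print(f"tiled, N=10M, fp32, tile=128: {tiled_working_bytes(10_000_000) / GB:.2f} GB")
```

The dense cost grows quadratically, so going from 20k to 10M points multiplies it by 250,000, while the tiled bound grows only linearly in N.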

The primitive

from flashdiffusion import FlashDiffusion

# K(i,j) = exp(-beta * score_mod(xi, xj))
# with Coifman-Lafon alpha-normalisation and optional Doob h-transform
out = FlashDiffusion(X, X, V, beta=1.0, alpha=0.5, h=None)

FlashDiffusion(Q, K, V, beta, alpha, h) is the single kernel underlying everything:

  • DMAP mode (Q=K=X, V=eigvec guess): symmetric diffusion matvec for Lanczos
  • AMAP mode (Q≠K, asymmetric W): directed/NESS Markov operator (DAC §3)
  • Doob mode (h≠None): h-transform of any of the above (DAC §4)
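To make the three modes concrete, here is a minimal dense NumPy sketch of the operator that the tiled kernel is described as applying. It is not the package's implementation: the function name `dense_markov_operator` is hypothetical, `score_mod` is assumed to be the squared Euclidean distance, and the result is written in row-stochastic (Markov) form rather than the symmetric form a Lanczos solver would use. It builds the full N×N matrix, so it is only suitable for small N.

```python
import numpy as np


def dense_markov_operator(X, beta=1.0, alpha=0.5, h=None):
    """Dense reference: build the row-stochastic operator explicitly (small N only)."""
    # Assumption: score_mod(xi, xj) is the squared Euclidean distance.
    sq = np.sum(X**2, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0)
    K = np.exp(-beta * d2)
    # Coifman-Lafon alpha-normalisation: divide out the density estimate q.
    q = K.sum(axis=1)
    K_alpha = K / np.outer(q**alpha, q**alpha)
    # Row-normalise to obtain a Markov matrix P.
    P = K_alpha / K_alpha.sum(axis=1, keepdims=True)
    if h is not None:
        # Doob h-transform for a positive vector h:
        # P_h(i, j) = P(i, j) * h(j) / (P h)(i)
        P = P * h[None, :] / (P @ h)[:, None]
    return P


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
V = rng.normal(size=(50, 2))

P = dense_markov_operator(X, beta=1.0, alpha=0.5)
out = P @ V                                        # dense analogue of the matvec
P_h = dense_markov_operator(X, h=np.exp(X[:, 0]))  # an arbitrary positive h
```

Both `P` and `P_h` have rows summing to one, which is a convenient sanity check when comparing a tiled kernel against a dense reference.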

Design

kernel.py        FlashDiffusion primitive — tiled O(N·tile) memory matvec
dmap.py          DiffusionMap class: fit() precomputes rscale; transform() calls kernel
eigensolver.py   Lanczos/LOBPCG with bf16→fp32→fp64 precision ramp; orthog in fp64
cdc.py           Carré du Champ operator Γ(f,g) and metric tensor M = ΣᵢGᵢGᵢᵀ
utils.py         Bandwidth selection (self-tuning, median), Nyström extension
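The idea behind "orthog in fp64" can be illustrated with a small Lanczos loop in which the matvec runs in reduced precision while the basis and its reorthogonalisation stay in float64. This is a sketch under assumptions, not the code in eigensolver.py: the function name `lanczos_top_eigs` is hypothetical, float32 stands in for the low-precision stage, and the test operator is a small dense symmetric-normalised Gaussian kernel.

```python
import numpy as np


def lanczos_top_eigs(matvec, n, k=4, n_iter=40, seed=0):
    """Lanczos with full reorthogonalisation in fp64; the matvec sees fp32 input."""
    rng = np.random.default_rng(seed)
    Q = np.zeros((n, n_iter))            # fp64 Krylov basis
    a = np.zeros(n_iter)                 # diagonal of the tridiagonal matrix T
    b = np.zeros(n_iter)                 # off-diagonal of T
    q = rng.normal(size=n)
    q /= np.linalg.norm(q)
    m = n_iter
    for j in range(n_iter):
        Q[:, j] = q
        # Low-precision matvec, result promoted to fp64 before any reduction.
        w = np.array(matvec(q.astype(np.float32)), dtype=np.float64)
        a[j] = q @ w
        basis = Q[:, : j + 1]
        for _ in range(2):               # two passes of Gram-Schmidt in fp64
            w -= basis @ (basis.T @ w)
        b[j] = np.linalg.norm(w)
        if b[j] < 1e-10:                 # Krylov space exhausted
            m = j + 1
            break
        q = w / b[j]
    T = np.diag(a[:m]) + np.diag(b[: m - 1], 1) + np.diag(b[: m - 1], -1)
    evals, evecs = np.linalg.eigh(T)
    idx = np.argsort(evals)[::-1][:k]
    return evals[idx], Q[:, :m] @ evecs[:, idx]


# Small symmetric test operator S = D^{-1/2} K D^{-1/2}, stored in fp32.
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))
sq = np.sum(X**2, axis=1)
K = np.exp(-np.maximum(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0))
d = K.sum(axis=1)
S = (K / np.sqrt(np.outer(d, d))).astype(np.float32)

evals, evecs = lanczos_top_eigs(lambda v: S @ v, n=200, k=4)
```

For this normalisation the largest eigenvalue is 1 in exact arithmetic, so `evals[0]` gives a quick indication of how much accuracy the reduced-precision matvec costs.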

Physics background

Transformer attention, diffusion maps, and magnetic Laplacians are three regimes of a single Markov geometry (Candanedo 2025). The EQ/NESS classification:

  • EQ (symmetric W): DMAP — reversible diffusion, detailed balance
  • NESS (asymmetric W): self-attention — directed, probability current ≠ 0
  • Doob deformation: h-transform preserves EQ class (Theorem 5.1, DAC)

The Coifman-Lafon α-normalisation is an exact Doob transform (Corollary 4.2, DAC).
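The EQ/NESS distinction can be checked numerically on a toy example. The sketch below row-normalises a weight matrix W into a Markov matrix P, estimates its stationary distribution π by power iteration, and evaluates the probability current J(i,j) = π_i P(i,j) − π_j P(j,i). The helper names and the random 6×6 weights are illustrative assumptions; the code only demonstrates the standard detailed-balance property and does not reproduce any result from the cited paper.

```python
import numpy as np


def stationary_distribution(P, n_iter=500):
    """Left power iteration: pi <- pi P, renormalised at each step."""
    pi = np.full(P.shape[0], 1.0 / P.shape[0])
    for _ in range(n_iter):
        pi = pi @ P
        pi /= pi.sum()
    return pi


def probability_current(W):
    """J(i, j) = pi_i P(i, j) - pi_j P(j, i) for the row-normalised W."""
    P = W / W.sum(axis=1, keepdims=True)
    pi = stationary_distribution(P)
    flux = pi[:, None] * P
    return flux - flux.T


rng = np.random.default_rng(0)
A = rng.uniform(0.1, 1.0, size=(6, 6))

J_eq = probability_current(A + A.T)   # symmetric W: reversible chain
J_ness = probability_current(A)       # generic asymmetric W

print("max |J|, symmetric W: ", np.abs(J_eq).max())
print("max |J|, asymmetric W:", np.abs(J_ness).max())
```

For symmetric W the current vanishes up to rounding error (detailed balance), whereas a generic asymmetric W leaves a nonzero current even though the net flow out of each state still sums to zero.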

FlashDiffusion SM120 — RTX 6000 / Consumer Blackwell

What's in this package

flashdiffusion/csrc/flash_diffusion_sm120.cu   ← kernel
flashdiffusion/kernel_cuda_sm120.py            ← Python wrapper
setup_sm120.py                                  ← build script
sm120_notebook.py                               ← notebook cells

SM120 vs SM80 kernel differences

Feature          SM80 (A100)        SM120 (RTX 5090)
MMA instruction  mma.sync f16→f32   mma.sync f16→f32 (same)
Tile size        64×64              128×128
SMEM async       __syncthreads      cp.async pipeline
TMA              no                 yes (not used yet)
TMEM             no                 no (SM100 only)
Cluster shape    1×1×1              1×1×1 (no multicast)
WGMMA/UMMA       no                 no (SM100 only)

The inner GEMM instruction is identical to SM80. The speedup comes from:

  1. Larger 128×128 tiles → better arithmetic intensity
  2. cp.async pipeline → hides GMEM→SMEM load latency
  3. Higher RTX 5090 clock speed
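The first point can be quantified with a simple model. The sketch below assumes one tile of query rows and one tile of key rows are each loaded once from global memory and used to form a tile×tile block of dot products over d features, with fp16 inputs. The function name, the feature dimension d=64 and this traffic model are simplifying assumptions rather than measurements of the kernel.

```python
def tile_arithmetic_intensity(tile, d, bytes_per_elem=2):
    """FLOPs per byte loaded for one tile x tile block of pairwise dot products."""
    flops = 2 * tile * tile * d                    # multiply + add per (i, j, feature)
    bytes_loaded = 2 * tile * d * bytes_per_elem   # one Q tile and one K tile
    return flops / bytes_loaded


for tile in (64, 128):
    print(f"tile={tile}: {tile_arithmetic_intensity(tile, d=64)} FLOP/byte")
# tile=64: 32.0 FLOP/byte
# tile=128: 64.0 FLOP/byte
```

Under this model the intensity is tile / bytes_per_elem, independent of d, so doubling the tile edge from 64 to 128 doubles the work done per byte fetched.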

Build

# on a machine with RTX 5090
FLASHDIFFUSION_BUILD_CUDA=1 \
TORCH_CUDA_ARCH_LIST="12.0" \
python setup_sm120.py build_ext --inplace

Or build from a notebook (see Cell 2 of sm120_notebook.py).

Expected performance vs SM80

N=200k  SM80 A100  precompute: 0.4s   Lanczos: ~100s
N=200k  SM120 5090 precompute: ~0.3s  Lanczos: ~60s  (estimate)

RTX 5090 has slightly lower memory bandwidth than the 80 GB A100 (about 1.8 TB/s vs 2.0 TB/s), but our scalar kernel is compute-bound, not memory-bound, so bandwidth is not the limiting factor. The cp.async pipeline and larger tiles are the main wins.

Next step: TiledMMA on SM120

The mma.sync.aligned.m16n8k16 instruction on SM120 uses the same CuTe atom as SM80: SM80_16x8x16_F32F16F16F32. Wiring TiledMMA replaces the scalar inner loop, giving ~10× speedup. This is flash_diffusion_sm120_v2.cu — to be added next.

CUDA roadmap

Current: NumPy reference (validates correctness, runs on CPU). Next: CuTe kernel — SM90 tensor cores, fused prepass+matvec, no online softmax needed (precomputed rscale eliminates the log-sum-exp recurrence that drives FA4 complexity).

prepass:  out_i = Σⱼ K(i,j) · wⱼ          # scalar reduction per row, one kernel
matvec:   out_i = rscaleᵢ · Σⱼ K(i,j) · rscaleⱼ · vⱼ  # vector, same tiling

Both kernels share the same tile loop; only the epilogue differs.
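A NumPy sketch of this shared tile loop is given below. It is a CPU illustration, not the package's kernel: the function names are hypothetical, the kernel is assumed to be a Gaussian of the squared Euclidean distance, and rscale is assumed to be q^(-alpha) / sqrt(d), where q and d are the row sums before and after alpha-normalisation. That choice is consistent with the two formulas above and with the two prepass reductions mentioned in the quick start, but the source does not spell it out.

```python
import numpy as np


def tiled_kernel_apply(X, w, beta, tile=64, left=None):
    """out_i = left_i * sum_j K(i, j) w_j, visiting K one tile x tile block at a time."""
    n = X.shape[0]
    w2 = w.reshape(n, -1)
    out = np.zeros(w2.shape, dtype=np.float64)
    sq = np.sum(X**2, axis=1)
    for i0 in range(0, n, tile):
        i1 = min(i0 + tile, n)
        for j0 in range(0, n, tile):
            j1 = min(j0 + tile, n)
            # Only this tile x tile block of K ever exists in memory.
            d2 = sq[i0:i1, None] + sq[None, j0:j1] - 2.0 * X[i0:i1] @ X[j0:j1].T
            out[i0:i1] += np.exp(-beta * np.maximum(d2, 0.0)) @ w2[j0:j1]
    if left is not None:
        out *= left[:, None]             # epilogue: per-row scaling
    return out.reshape(w.shape)


def precompute_rscale(X, beta, alpha, tile=64):
    """Two prepass reductions giving the assumed rscale = q^(-alpha) / sqrt(d)."""
    q = tiled_kernel_apply(X, np.ones(len(X)), beta, tile)        # q_i = sum_j K_ij
    qa = q ** (-alpha)
    d = qa * tiled_kernel_apply(X, qa, beta, tile)                # alpha-normalised row sums
    return qa / np.sqrt(d)


def matvec(X, v, rscale, beta, tile=64):
    """out_i = rscale_i * sum_j K(i, j) * rscale_j * v_j."""
    return tiled_kernel_apply(X, rscale * v, beta, tile, left=rscale)


rng = np.random.default_rng(0)
X = rng.normal(size=(150, 3))            # 150 is deliberately not a multiple of the tile
v = rng.normal(size=150)
rscale = precompute_rscale(X, beta=1.0, alpha=0.5)
y = matvec(X, v, rscale, beta=1.0)
```

The prepass and the matvec differ only in what is passed as `w` and `left`, which mirrors the statement that only the epilogue changes between the two kernels.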

Installation

pip install -e ".[dev]"

Quick start

from flashdiffusion import DiffusionMap

dm = DiffusionMap(beta=1.0, alpha=0.5, n_components=8)
dm.fit(X)               # two prepass tiled reductions, O(N·tile) memory
coords = dm.transform() # Lanczos eigensolver, mixed precision

References

  • Coifman & Lafon (2006) — diffusion maps
  • Candanedo (2025) — diffusion-attention connection, bidivergence, Doob classification
  • Dao et al. (2024) — FlashAttention-4
  • Rohrdanz, Zheng, Clementi (2011) — diffusion maps for MD

Download files

Source Distribution

flashdiffusion-0.1.0.tar.gz (33.8 kB)

Built Distribution

flashdiffusion-0.1.0-py3-none-any.whl (32.0 kB, Python 3)

File details

Details for the file flashdiffusion-0.1.0.tar.gz.

File metadata

  • Size: 33.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Algorithm    Hash digest
SHA256       f6c6a5920b2b28000f26a4f7685f86a918b33e94b380be74f15a3d8f081aba1f
MD5          2bf38bf1988e5771efbe7c66a0bace87
BLAKE2b-256  6255e022a395158f78da51c66229bddc80fe158a30cd8e6a869812d1a4fa96c0

Provenance

Publisher: publish.yml on sparsetrace/FlashDiffusion

File details

Details for the file flashdiffusion-0.1.0-py3-none-any.whl.

File metadata

  • Size: 32.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Algorithm    Hash digest
SHA256       a31a7ceac8bc8a98f08f22c1f5585801ab5d4de70c456e6abe86078b6ac5f1df
MD5          e92c96f262a2ea97da57fac59e0f2ce8
BLAKE2b-256  c8c838cf9d3d676262f8a83ea122ac4758c46515454e83b6ec30bab8d80f57de

Provenance

Publisher: publish.yml on sparsetrace/FlashDiffusion
