Fully-fused 3D Local Normalized Cross-Correlation loss (CUDA)

Fused LNCC

A fully-fused CUDA kernel for the 3D Local (squared) Normalized Cross-Correlation loss, a standard similarity metric in deformable image registration that is also used as a perceptual/structural loss.

It is ~3.5x faster and ~3x lighter on memory than the previous fastest differentiable 3D LNCC (FFDP, ICLR'26 Oral, the fused-kernel framework built on the FireANTs registration library), ~6-10x faster than MONAI, and ~16-18x faster than naive PyTorch, while producing matching gradients. These are loss-kernel numbers; the end-to-end registration speedup is smaller (see below). Verified on V100, A100, A40, and Blackwell.

Install

pip install -e . --no-build-isolation     # needs an NVIDIA GPU and a CUDA toolchain (nvcc, gcc)

This is a source build that compiles the kernel for every supported arch, so expect it to take a few minutes.

The built wheel contains SASS for sm_70 through sm_90 plus sm_120 and falls back to JIT compilation from PTX on newer GPUs, so it should run out of the box from Volta through Blackwell (V100, A100, A40, and Blackwell are verified; see GPU support). Turing (sm_75) runs, but its 64 KB shared-memory limit restricts it to k ≤ 5. Only CUDA is supported; there is no Apple/Metal or AMD/ROCm backend.

Tested with PyTorch 2.3 to 2.12 and CUDA 11.8 to 13.x (it is a source build, so it compiles against your own torch + CUDA). One caveat: CUDA 13 dropped Volta, so a V100 needs a CUDA 12.x or older toolkit (the build selects arches automatically).

Usage

from fused_lncc import fused_lncc_loss

loss = fused_lncc_loss(pred, target, kernel_size=7)   # pred, target: (N,C,D,H,W) CUDA -> scalar in [0,1]
loss.backward()                                        # gradient flows through `pred`

It is a standard torch.autograd.Function, so it drops into any training loop and composes with other loss terms:

import torch.nn.functional as F

for x, target in loader:          # model, opt, loader: your own training objects
    opt.zero_grad()
    pred = model(x)
    loss = F.l1_loss(pred, target) + 0.5 * fused_lncc_loss(pred, target, kernel_size=7)
    loss.backward()
    opt.step()

Supports fp32 or bf16 inputs and an odd kernel_size in {3, 5, 7, 9}. It returns 1 - mean(local squared correlation), so lower is better. Forward values match MONAI's rectangular LNCC and a PyTorch reference to ~1e-7, and the analytic gradient matches with cosine similarity >0.9999 (relative error <1e-3).
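For reference, here is the loss written out from the five window sums listed under "Why it's fast". This is our reading, following the formulation of MONAI's LocalNormalizedCrossCorrelationLoss; the package's border handling and the clamp used for degenerate windows are not specified here. W(v) is the k×k×k window centred at voxel v, n = k³, and M is the number of voxels averaged over.

```latex
\begin{aligned}
\mathrm{cross}_v &= \sum_{W(v)} p\,t - \frac{1}{n}\Big(\sum_{W(v)} p\Big)\Big(\sum_{W(v)} t\Big), \\
s^2_{p,v} &= \sum_{W(v)} p^2 - \frac{1}{n}\Big(\sum_{W(v)} p\Big)^2, \qquad
s^2_{t,v} = \sum_{W(v)} t^2 - \frac{1}{n}\Big(\sum_{W(v)} t\Big)^2, \\
\mathrm{cc}_v &= \frac{\mathrm{cross}_v^2}{s^2_{p,v}\, s^2_{t,v}} \in [0, 1], \qquad
\mathcal{L} = 1 - \frac{1}{M} \sum_v \mathrm{cc}_v .
\end{aligned}
```

cc_v is the squared Pearson correlation of p and t over the window (the 1/n normalizations cancel in the ratio), which is why it lies in [0, 1], and why a locally anticorrelated pair (t = -p) scores as well as a correlated one.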

Notes: the gradient flows through pred only; target is treated as a fixed reference. Inputs must be fp32 or bf16, not fp16, so under torch.autocast("cuda", dtype=torch.float16) cast first, e.g. fused_lncc_loss(pred.float(), target.float(), 7). bf16 autocast works directly.

Performance

Speed and peak memory for the forward + backward step as the volume grows, against the common alternatives:

[Figure: scaling of forward + backward time and peak memory with volume size]

Headline numbers at (2,16,128³), A40, fp32, k=7, forward + backward (lower is better):

Method             Time (ms)   Peak VRAM (GB)   Slower / more VRAM than fused_lncc
fused_lncc          24.5        1.07             1x / 1x
FFDP (ICLR'26)      85.7        3.22             3.5x / 3.0x
MONAI              162.5        7.21             6.6x / 6.7x
sep + compile       81.1        4.03             3.3x / 3.8x
naive (PyTorch)    432.5        3.49             18x / 3.3x

That is 3.5x faster and 3x less memory than the previous state of the art (FFDP), and the gap holds across V100, A100, A40, and Blackwell. At high resolution the memory advantage becomes an out-of-memory boundary: at 256³, fused_lncc runs in ~13 GB while the baselines need 30-61 GB and run out of memory on a 24 GB card.
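One way to read the peak-VRAM column is in units of one input-sized tensor. This back-of-envelope sketch assumes the headline shape (2, 16, 128³) in fp32 and decimal gigabytes (1 GB = 10⁹ bytes); the table does not state which unit it uses.

```python
# Back-of-envelope: peak VRAM from the headline table, in units of one
# input-sized fp32 tensor. Assumes shape (2, 16, 128^3) and 1 GB = 1e9 bytes.
N, C, D, H, W = 2, 16, 128, 128, 128
bytes_per_tensor = N * C * D * H * W * 4       # fp32 = 4 bytes per voxel
gb_per_tensor = bytes_per_tensor / 1e9         # about 0.268 GB

peak_gb = {
    "fused_lncc": 1.07,
    "FFDP (ICLR'26)": 3.22,
    "MONAI": 7.21,
    "sep + compile": 4.03,
    "naive (PyTorch)": 3.49,
}
tensors = {name: gb / gb_per_tensor for name, gb in peak_gb.items()}

for name, count in tensors.items():
    print(f"{name:16s} ~{count:4.1f} input-sized tensors")
```

Under those assumptions fused_lncc peaks at about four input-sized tensors (pred and target alone account for two), FFDP at about 12, and MONAI at about 27, which fits the description under "Why it's fast" of baselines that materialize every intermediate.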

What the benchmark measures. All contenders run in the regime fused_lncc supports: rectangular box window, gradient to pred only, and the exact backward. FFDP is run in the matching mode: only pred requires grad, so it takes its lean 3C-channel backward path, and it uses its exact gradient rather than the cheaper use_ants_gradient approximation. The comparison is apples-to-apples within that scope; it is not a claim about everything FFDP can do (Gaussian windows, large kernels, dual-image gradients, gigavoxel sharding).

End-to-end note. The numbers above are for the loss kernel in isolation. In a full registration pipeline, the speedup depends on how much of each iteration the loss accounts for. Against the non-fused MONAI LNCC it stays large (~3x per iteration), because the loss dominates the step. Against FFDP, which already fuses the loss, it is modest (~1.1-1.15x, or ~1.3-1.4x with a matched exact gradient) at equal memory and equal registration quality, because the fused loss is only ~38% of an iteration. So fused_lncc helps most where the loss isn't already fused (most LNCC pipelines). Against FFDP the speed gain is small, but fused_lncc stays more accurate (exact gradient) and is simpler to deploy (standalone, no per-arch build).
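The gap between kernel and end-to-end speedups is Amdahl's law: if the loss is a fraction f of each iteration and becomes s times faster, the iteration speeds up by 1 / ((1 - f) + f / s). A minimal sketch plugging in the ~38% loss share from the note above and, as an assumption, the 3.5x A40 kernel speedup from the headline table (a registration workload's shapes may differ from that benchmark):

```python
def end_to_end_speedup(loss_fraction: float, loss_speedup: float) -> float:
    """Amdahl's law: only the loss share of an iteration is accelerated."""
    return 1.0 / ((1.0 - loss_fraction) + loss_fraction / loss_speedup)


f = 0.38  # loss share of an FFDP-based registration iteration (from the text)
s = 3.5   # assumed: headline A40 loss-kernel speedup over FFDP

estimate = end_to_end_speedup(f, s)  # about 1.37x
ceiling = 1.0 / (1.0 - f)            # about 1.61x, even for an infinitely fast loss
print(f"estimated {estimate:.2f}x, ceiling {ceiling:.2f}x")
```

The estimate lands inside the ~1.3-1.4x quoted for a matched exact gradient, and the ceiling shows why no loss kernel, however fast, can deliver more than about 1.6x once the loss is only ~38% of the step.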

Full benchmarks, the four-GPU comparison, the memory/OOM envelope, and end-to-end registration are in BENCHMARKS.md.

Scope vs FFDP/FireANTs

fused_lncc is a standalone loss that covers the common case and is fastest there; FFDP is the more general kernel, shipped as part of the FireANTs registration framework.

              fused_lncc                                        FFDP / FireANTs
Window        rectangular box                                   box + Gaussian
Kernel size   odd k ∈ {3,5,7,9}                                 arbitrary odd k
Gradient      pred only (asymmetric)                            pred and target
Multi-GPU     data-parallel (DDP, different volumes per GPU)    data-parallel + grid-parallel (one gigavoxel volume sharded across GPUs)
Form factor   a single torch.autograd.Function                  full registration framework (warper, optimizer, pyramid)

When the defaults match your problem (rectangular window, small kernel, registering a moving image to a fixed reference, on a single GPU), fused_lncc is the faster, lighter choice. Reach for FFDP/FireANTs if you need a Gaussian window, a large kernel, gradients to both images (symmetric/SyN registration, or atlas/template building where the template is also optimized), or multi-GPU sharding of a single volume too large for one card.

Why it's fast

LNCC needs, at every voxel, five local box-sums (Σp, Σt, Σp², Σt², Σpt) over a window, then a per-window correlation. Baselines run five separate convolutions and materialize every intermediate; even the prior fused kernel still routes the convolutions through cuDNN. We pull the whole computation into one shared-memory-tiled kernel, in both directions:

  • Forward: each block loads its tile once and computes all five statistics, the correlation, and the backward adjoints in a single pass, so the intermediates never touch global memory.
  • Backward: a single analytic kernel, dloss/dp = -(1/M)·(box(A) + 2p·box(B) + t·box(C)), with no autograd tape, which is where the ~3x memory saving comes from.
  • fp32 accumulation guards the variance's sum-of-squares cancellation; degenerate windows clamp to keep the loss finite and in [0,1].
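The backward formula can be checked numerically. The sketch below is a dependency-free Python reference on a tiny volume, not the package's CUDA code, and it rests on assumptions the text does not spell out: zero padding at the borders, M taken as the number of voxels in the mean, an illustrative eps floor on the variances, and A, B, C read as the derivatives of each window's squared correlation with respect to Σp, Σp², and Σpt. The names box, vmap, and loss_and_grad are hypothetical. Under that reading the adjoints are computed per window, box-summed once more, and combined exactly as in the formula above; the script compares the result against central finite differences.

```python
import random


def box(vol, k):
    """Zero-padded k^3 box sum at every voxel of a nested-list (D, H, W) volume."""
    D, H, W, r = len(vol), len(vol[0]), len(vol[0][0]), k // 2
    return [[[sum(vol[zz][yy][xx]
                  for zz in range(max(0, z - r), min(D, z + r + 1))
                  for yy in range(max(0, y - r), min(H, y + r + 1))
                  for xx in range(max(0, x - r), min(W, x + r + 1)))
              for x in range(W)] for y in range(H)] for z in range(D)]


def vmap(f, *vols):
    """Apply f voxel-wise across equally shaped volumes."""
    return [[[f(*v) for v in zip(*rows)] for rows in zip(*planes)]
            for planes in zip(*vols)]


def loss_and_grad(p, t, k=3, eps=1e-5):
    n = float(k ** 3)                        # window volume
    M = len(p) * len(p[0]) * len(p[0][0])    # voxels in the mean (assumption)
    Sp, St = box(p, k), box(t, k)
    Spp = box(vmap(lambda a: a * a, p), k)
    Stt = box(vmap(lambda b: b * b, t), k)
    Spt = box(vmap(lambda a, b: a * b, p, t), k)

    def per_window(sp, st, spp, stt, spt):
        cross = spt - sp * st / n
        vp_raw = spp - sp * sp / n
        vp, vt = max(vp_raw, eps), max(stt - st * st / n, eps)
        cc = cross * cross / (vp * vt)
        C = 2.0 * cross / (vp * vt)               # d cc / d Spt
        B = -cc / vp if vp_raw > eps else 0.0     # d cc / d Spp (0 when clamped)
        A = -C * st / n - 2.0 * B * sp / n        # d cc / d Sp
        return cc, A, B, C

    out = vmap(per_window, Sp, St, Spp, Stt, Spt)
    cc, A, B, C = (vmap(lambda o, i=i: o[i], out) for i in range(4))
    loss = 1.0 - sum(v for pl in cc for row in pl for v in row) / M
    # A centred cube is symmetric, so the adjoint of a box sum is a box sum.
    bA, bB, bC = box(A, k), box(B, k), box(C, k)
    grad = vmap(lambda a, b, c, pv, tv: -(a + 2.0 * pv * b + tv * c) / M,
                bA, bB, bC, p, t)
    return loss, grad


random.seed(0)
S = 4
p = [[[random.random() for _ in range(S)] for _ in range(S)] for _ in range(S)]
t = [[[random.random() for _ in range(S)] for _ in range(S)] for _ in range(S)]
loss, grad = loss_and_grad(p, t)

# Central finite differences on every voxel of pred.
h, max_err = 1e-6, 0.0
for z in range(S):
    for y in range(S):
        for x in range(S):
            orig = p[z][y][x]
            p[z][y][x] = orig + h
            lp, _ = loss_and_grad(p, t)
            p[z][y][x] = orig - h
            lm, _ = loss_and_grad(p, t)
            p[z][y][x] = orig
            max_err = max(max_err, abs((lp - lm) / (2 * h) - grad[z][y][x]))

print(f"loss={loss:.6f}  max |analytic - finite difference| = {max_err:.2e}")
```

The second box sum appears because every voxel sits in the k³ windows around it, so the backward needs only box sums of per-window quantities. At p = t the terms cancel and the gradient vanishes, as expected at the loss minimum.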

GPU support

Verified on V100 (sm_70), A100 (sm_80), A40 (sm_86), and Blackwell RTX PRO 6000 (sm_120). The same wheel ran on each (no rebuild on V100/A100; PTX-JIT on Blackwell), with all tests and compute-sanitizer clean on fp32 and bf16. The full per-arch matrix and caveats (including Turing's shared-memory limit) are in BENCHMARKS.md.

Acknowledgments

  • fused-ssim (Rahul Goel): this project is directly inspired by it. SSIM and LNCC are the same shape of computation (local statistics via a separable windowed convolution plus a per-window formula), and the shared-memory-tiling and fused-backward design here mirrors fused-ssim's, applied to the box-window LNCC.
  • FFDP (Jena et al., ICLR'26 Oral): the prior fused 3D LNCC kernel and the analytic-backward idea; used here as the primary speed/memory baseline. FFDP is the fused-kernel framework built on the FireANTs registration library (Jena et al., Nature Communications).
  • MONAI LocalNormalizedCrossCorrelationLoss: the reference semantics we value-match against.

License

MIT, see LICENSE.

Download files

Source Distribution

fused_lncc-0.1.0.tar.gz (19.2 kB view details)

Uploaded Source

File details

Details for the file fused_lncc-0.1.0.tar.gz.

File metadata

  • Download URL: fused_lncc-0.1.0.tar.gz
  • Upload date:
  • Size: 19.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.14

File hashes

Hashes for fused_lncc-0.1.0.tar.gz
Algorithm Hash digest
SHA256 fc937220d3cc6d7e8ba4ef50cd83be6f9d91d03a983dca23691d290805e8e508
MD5 8517166e84db71a4de2a966bb059e05e
BLAKE2b-256 f217a8bd4a139682ab8707bef92eb2901c9f3d33132629e1da9517308e2a7832
