Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.

DCB — Differentiable Critical Bandwidth

License: Apache 2.0 | Python 3.9+

A PyTorch package that makes Silverman's critical bandwidth test (1981) fully differentiable, enabling end-to-end gradient-based optimisation over the modal structure of continuous distributions.

Overview

h_crit is the smallest bandwidth h at which the Gaussian kernel density estimate (KDE) of the sample has at most k modes (k = target_modes; the default k = 1 tests unimodality). It is a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable step with a smooth surrogate, then uses the Implicit Function Theorem (IFT) to compute exact gradients through the root-finding step at O(1) memory cost in the number of root-finding iterations.

import torch
from dcb import DCBLayer, TrainingLayer

X = torch.randn(10_000, requires_grad=True)  # 1-D samples; see Known Limitations for n > 100M
layer = DCBLayer()
h_crit = layer(X)       # differentiable scalar
h_crit.backward()       # exact IFT gradients, accumulated in X.grad

# For repeated training-loop use with warm-start bracket caching:
layer = TrainingLayer(warm_start=True)
dataloader = [torch.randn(10_000) for _ in range(10)]  # stand-in for a real loader of 1-D batches
for batch in dataloader:
    h = layer(batch)    # 1.8× faster after first call on CPU; ~10× on CUDA with compile=True
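As a slow, non-differentiable reference for the statistic DCBLayer targets, h_crit can be found by bisection on h, because for a Gaussian kernel the number of KDE modes never increases as h grows (Silverman 1981). The plain-Python sketch below counts modes as + to - sign changes of f′_h on a fixed grid. It is an illustration under stated assumptions, not DCB's solver: the function names, the grid size m, and the starting bracket are hypothetical.

```python
import math


def count_modes(data, h, m=512):
    """Count modes of a Gaussian KDE as + to - sign changes of f'_h on an m-point grid."""
    lo, hi = min(data) - 3 * h, max(data) + 3 * h
    prev, modes = 0, 0
    for j in range(m):
        x = lo + (hi - lo) * j / (m - 1)
        # f'_h(x) has the sign of sum_i (X_i - x) * exp(-((x - X_i) / h)**2 / 2);
        # the positive factor 1 / (n * h**3 * sqrt(2*pi)) does not change the sign.
        s = sum((xi - x) * math.exp(-0.5 * ((x - xi) / h) ** 2) for xi in data)
        if s == 0:
            continue  # underflow far from every sample, or an exact stationary point
        if prev > 0 and s < 0:
            modes += 1
        prev = s
    return modes


def critical_bandwidth(data, k=1, rel_tol=1e-4, m=512):
    """Smallest h whose Gaussian KDE has at most k modes (hypothetical reference solver)."""
    spread = max(data) - min(data)
    if spread == 0:
        return 0.0  # all samples equal: a single mode at every h
    lo, hi = 1e-3 * spread, spread
    if count_modes(data, lo, m) <= k:
        return lo  # already at most k modes at the smallest h tried
    while count_modes(data, hi, m) > k:
        hi *= 2.0  # grow the upper end until it has at most k modes
    # Invariant: more than k modes at lo, at most k modes at hi.
    while hi - lo > rel_tol * hi:
        mid = 0.5 * (lo + hi)
        if count_modes(data, mid, m) <= k:
            hi = mid
        else:
            lo = mid
    return hi
```

For two samples at -1 and +1 the KDE is bimodal exactly when h < 1, so critical_bandwidth([-1.0, 1.0]) should land very close to 1.0. The grid-based count can miss modes that sit closer together than the grid spacing, so the result is approximate.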

Installation

pip install diffcb

Or from source:

git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
cd differentiable-critical-bandwidth
pip install -e ".[dev]"

Accuracy vs R's bw.crit

Validated against R's multimode::bw.crit(data, mod0=1) (Hall & York 2001). Same-sample protocol (identical data fed to both Python and R):

| n | DCB error vs R | Notes |
|---|---|---|
| 5K–25K | < 0.005% | Direct-KDE path, zero histogram bias |
| 100K | 0.003% | FFT histogram path, G=16384 |
| 1M | 0.003% | FFT path |
| 10M | 0.003% | FFT path |
| 100M+ | < 0.01% | Histogram-dominated; sketch available |

Independent-sample error (~0.2–0.5%) reflects natural sampling variability (two RNGs), not algorithmic error. The 0.003% algorithmic error is of the same order of magnitude as R's own ~0.001% numerical noise floor.

Hardware Performance (v0.1.6)

| n | CPU (Apple M) | MPS | P100 GPU |
|---|---|---|---|
| 10K | 2,300 ms | 1,400 ms | 107 ms |
| 50K | 2,900 ms | 1,700 ms | 167 ms |
| 100K | 265 ms | 248 ms | 35 ms |
| 1M | 269 ms | 189 ms | 36 ms |
| 10M | 544 ms | n/a | 44 ms |

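P100 speedup: 43–116× vs CPU, peaking at 116× for n=50K (direct-KDE GPU parallelism). The Apple M timings in the table above imply smaller ratios (about 7–22×), so these figures appear to use a different CPU baseline.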

Cumulative speedup vs v0.1.4 on CPU: 1.1× (100K), 1.7× (1M), 4.2× (10M).

API Reference

DCBLayer

DCBLayer(
    target_modes=1,        # target number of modes (default 1)
    use_fft=True,          # FFT path for n > 50K (default True)
    max_n_exact=None,      # sketch above this n (None = always exact)
    G_min=16384,           # minimum FFT histogram bins (accuracy ↑ with G)
    use_richardson=True,   # Richardson extrapolation on h_crit (~30% lower error)
    direct_n_max=25_000,   # use direct-KDE (no histogram) for n ≤ this
    direct_M=2048,         # direct-KDE evaluation grid size
    use_compile=False,     # infrastructure flag; use TrainingLayer for compile
    safe_backward=False,   # clamp IFT denominator near bifurcations
)
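The IFT backward pass, and what safe_backward guards against, can be illustrated on a scalar toy problem. Assuming h_crit is the root of some smooth surrogate F(h, θ) = 0 (DCB's actual F is not given here), the IFT gives dh*/dθ = −(∂F/∂θ)/(∂F/∂h). This needs only the converged root, not the solver's iterates, which is why memory does not grow with the number of root-finding steps. The toy F(h, θ) = h² − θ, the function names, and the sign-preserving clamp used for min_denom are hypothetical; DCB's actual safe_backward threshold and rule may differ.

```python
import math


def solve_root(F, theta, lo, hi, iters=60):
    """Forward pass: bisection for F(h, theta) = 0, assuming F(lo) and F(hi) differ in sign."""
    f_lo = F(lo, theta)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        f_mid = F(mid, theta)
        if (f_mid > 0) == (f_lo > 0):
            lo, f_lo = mid, f_mid  # root lies in [mid, hi]
        else:
            hi = mid  # root lies in [lo, mid]
    return 0.5 * (lo + hi)


def ift_grad(dF_dh, dF_dtheta, h_star, theta, min_denom=None):
    """Backward pass: dh*/dtheta = -(dF/dtheta) / (dF/dh), evaluated at the root only.

    If min_denom is set, |dF/dh| is clamped from below (keeping its sign) so the
    gradient stays bounded where dF/dh approaches zero (hypothetical clamp rule).
    """
    denom = dF_dh(h_star, theta)
    if min_denom is not None and abs(denom) < min_denom:
        denom = math.copysign(min_denom, denom)
    return -dF_dtheta(h_star, theta) / denom


# Toy surrogate (hypothetical): F(h, theta) = h**2 - theta, so h* = sqrt(theta).
def F(h, t):
    return h * h - t


def dF_dh(h, t):
    return 2.0 * h


def dF_dtheta(h, t):
    return -1.0


theta = 2.0
h_star = solve_root(F, theta, 0.0, 10.0)
grad = ift_grad(dF_dh, dF_dtheta, h_star, theta)  # analytic value: 1 / (2 * sqrt(theta))
```

A central finite difference of solve_root around θ = 2 agrees with grad, while the backward call itself never touches the 60 bisection iterates.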

TrainingLayer (for ML training loops)

from dcb import TrainingLayer

layer = TrainingLayer(
    warm_start=True,    # cache h_prev; init bracket to [0.95h, 1.05h] → 1.8× CPU speedup
    compile=False,      # torch.compile opt-in (requires float32; on CUDA, Python ≤ 3.11 or torch ≥ 2.4)
    warm_margin=0.05,   # bracket half-width around cached h_crit
    **dcb_kwargs,       # any DCBLayer parameter
)
layer.reset_cache()     # call on distribution shift
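The warm-start idea in the TrainingLayer comments can be sketched independently of DCB: reuse the previous h_crit to open a narrow bracket [(1 − m)·h, (1 + m)·h] and widen it only if the target has moved outside. The class below is hypothetical and uses a generic monotone predicate in place of DCB's mode counter; the cold bracket, the widening rule, and the tolerance are assumptions, and how TrainingLayer handles a missed bracket is not documented here.

```python
class WarmStartBisection:
    """Hypothetical sketch: find the smallest h with pred(h) True by bisection,
    where pred is monotone (False below a threshold, True at and above it)."""

    def __init__(self, warm_margin=0.05, cold_bracket=(1e-3, 1e3), rel_tol=1e-4):
        self.warm_margin = warm_margin
        self.cold_bracket = cold_bracket
        self.rel_tol = rel_tol
        self.h_prev = None  # cached result from the previous call
        self.calls = 0      # predicate evaluations during the last solve()

    def reset_cache(self):
        self.h_prev = None  # e.g. after a distribution shift

    def _check(self, pred, h):
        self.calls += 1
        return pred(h)

    def solve(self, pred):
        self.calls = 0
        if self.h_prev is None:
            lo, hi = self.cold_bracket
        else:
            lo = self.h_prev * (1.0 - self.warm_margin)
            hi = self.h_prev * (1.0 + self.warm_margin)
        # Widen the bracket if the threshold moved outside it.
        while self._check(pred, lo):
            lo, hi = lo * 0.5, lo
        while not self._check(pred, hi):
            lo, hi = hi, hi * 2.0
        # Invariant: pred(lo) is False and pred(hi) is True.
        while hi - lo > self.rel_tol * hi:
            mid = 0.5 * (lo + hi)
            if self._check(pred, mid):
                hi = mid
            else:
                lo = mid
        self.h_prev = hi
        return hi
```

For h near 1, the cold bracket needs about 23 halvings to reach a relative width of 1e-4, while a ±5% warm bracket needs about 10, which is where a training-loop speedup of this kind comes from.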

Direct-KDE path (n ≤ 25K)

For small samples, DCB evaluates the KDE derivative f′_h directly on an M-point grid (direct_M) without histogram binning: O(n·M) work per evaluation and zero binning bias. On CPU this is 3–4× slower than the histogram path; on GPU it runs 80–96× faster than on CPU.

# Force direct-KDE for all n (accuracy benchmark):
layer = DCBLayer(direct_n_max=float('inf'))

# Disable direct-KDE (speed benchmark):
layer = DCBLayer(direct_n_max=0)
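The binning bias that the direct path avoids can be illustrated in plain Python. Both functions below compute the Gaussian-KDE derivative f′_h on a grid; the binned version first snaps each sample to the centre of one of G equal-width bins. This is a sketch under stated assumptions: the function names are hypothetical, DCB's binning scheme may differ, and its histogram path uses an FFT convolution, whereas this version sums over bins directly.

```python
import math


def _gauss_slope(points, weights, h, grid):
    """f'_h at each grid point for a Gaussian KDE with weighted centres."""
    n = sum(weights)
    c = 1.0 / (n * h * h * math.sqrt(2.0 * math.pi))
    out = []
    for x in grid:
        s = 0.0
        for p, w in zip(points, weights):
            u = (x - p) / h
            s += -w * u * math.exp(-0.5 * u * u)  # -u * exp(-u^2/2) is phi'(u) up to 1/sqrt(2*pi)
        out.append(c * s)
    return out


def kde_slope_direct(data, h, grid):
    """Direct evaluation: every sample contributes, O(n * M) work, no binning."""
    return _gauss_slope(data, [1] * len(data), h, grid)


def kde_slope_binned(data, h, grid, G):
    """Histogram evaluation: samples are snapped to the centres of G equal-width bins."""
    lo, hi = min(data), max(data)
    if hi == lo:
        return kde_slope_direct(data, h, grid)
    width = (hi - lo) / G
    counts = [0] * G
    for xi in data:
        counts[min(int((xi - lo) / width), G - 1)] += 1
    centres = [lo + (k + 0.5) * width for k in range(G)]
    return _gauss_slope(centres, counts, h, grid)
```

On the same data and grid, the largest gap between the binned and direct curves shrinks as G grows relative to the data range divided by h, which matches the "accuracy ↑ with G" note on G_min.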

Richardson extrapolation

By default (use_richardson=True), DCB runs a second bisection at G/2 (8192 with the default G=16384) and combines the two estimates as h̃ = (4·ĥ(G) − ĥ(G/2)) / 3, which reduces error by about 30%. On GPU this adds 38% overhead for less than 0.01% accuracy gain, so consider use_richardson=False in GPU training loops.
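The weights 4 and 3 follow from assuming that the leading error in ĥ(G) scales as 1/G², i.e. with the square of the bin width: halving G quadruples that term, and (4·ĥ(G) − ĥ(G/2))/3 cancels it. That this is the error model DCB relies on is inferred from the weights, not stated above. The sketch checks the cancellation on a made-up error model; the more modest ~30% reduction reported above suggests other error terms also matter in practice.

```python
def richardson(h_G, h_G_half):
    """Combine estimates from G and G/2 bins: (4 * h(G) - h(G/2)) / 3.

    If h(G) = h_true + a / G**2 + (higher-order terms), halving G multiplies the
    a-term by 4, so this combination cancels it and leaves only higher-order error.
    """
    return (4.0 * h_G - h_G_half) / 3.0


# Made-up error model for illustration: h(G) = h_true + a / G**2 + b / G**3.
def biased_estimate(G, h_true=1.0, a=50.0, b=100.0):
    return h_true + a / G**2 + b / G**3


G = 64
plain = biased_estimate(G)
extrapolated = richardson(biased_estimate(G), biased_estimate(G // 2))
```

With b = 0 the combination recovers h_true exactly; with b ≠ 0 only the smaller 1/G³ residue remains.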

Known Limitations

  • compile=True on MPS: blocked by float64 in _refine_hcrit fallback (fix in v0.1.7)
  • compile=True on CUDA with Python 3.12: requires torch ≥ 2.4 or Python ≤ 3.11
  • gradcheck with the direct-KDE forward: for n ≤ 25K the forward pass uses the exact KDE while the backward pass uses the smooth IFT surrogate, so gradcheck fails by design; use DCBLayer(direct_n_max=0) for gradcheck
  • n > 100M: requires a streaming histogram (not yet public API); as a workaround, use the sketch via max_n_exact=1_000_000

Confirmed Experimental Results

| Experiment | Result | Criterion |
|---|---|---|
| Accuracy vs R (same data, n=100K) | 0.003% | < 0.01% ✓ |
| Validation (m≥2, Marron-Wand) | R²=0.91, MAE=0.07, ρ=0.89 | R²≥0.85 ✓ |
| Speedup vs scipy (CUDA T4, n=8192) | 10.5× | ≥3× ✓ |
| GAN mode preservation | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
| Anomaly AUC (KDDCup99) | DCB=0.9982 vs IF=0.9867 | DCB≥IF ✓ |
| GPU speedup (P100, n=50K) | 116× vs CPU | |
| GPU speedup (P100, n=100K) | 43× vs CPU | |

Changelog

v0.1.6 (2026-05-30)

  • TrainingLayer: warm-start bracket caching (1.82× CPU speedup in training loops)
  • direct_mode_count_batch: direct-KDE path for n ≤ 25K (zero histogram bias; 80–96× GPU speedup)
  • Compile-ready trisection: tensor lo/hi, no .item() inside loop, fixed 16-round unroll
  • mode_count_from_C_batch returns Tensor(B,) (was list[int]) — enables torch.compile tracing

v0.1.5 (2026-05-29)

  • Richardson extrapolation on h_crit scalar (~30% error reduction; combines G=16384 and G=8192)
  • alloc/sync hygiene: removed nonzero_mask host sync (4.2× faster at n=10M)
  • Batched trisection bisection (one irfft dispatch per round)
  • Eliminated duplicate O(n) histogram in _refine_hcrit (C_external reuse)

v0.1.4 (2026-05-29)

  • FFT histogram path: C hoisted out of bisection loop (Worker 1)
  • Device-native histogram: CUDA histc, MPS scatter_add_, CPU bucketize+bincount
  • float32 FFT default; pad_factor 4→2 (halves irfft size)
  • Adaptive bisection early-exit

v0.1.1 (2026-05-29)

  • MPS histc OOM bug fixed (bucketize+bincount)
  • Sketch API: max_n_exact=1M, sketch_size=500K
  • Domain consistency and bias warning fixes

Repository Structure

dcb/
  layer.py         DCBLayer nn.Module + DCBFunction autograd
  solver.py        IFT root-finder, trisection bisection, Richardson pass
  fft_kde.py       FFT mode counter, direct_mode_count_batch, precompute_fft
  training.py      TrainingLayer with warm-start and compile support
  kde.py           Direct KDE derivatives (IFT backward path)
  utils.py         Grid, Silverman bandwidth, sg() stabiliser
experiments/       Reproduction scripts for all benchmarks and paper figures
tests/             Unit tests (45 passed, 1 xfailed)

License

Apache 2.0 — see LICENSE.

Download files

Source Distribution

diffcb-0.1.7.tar.gz (64.4 kB)

Uploaded Source

Built Distribution

diffcb-0.1.7-py3-none-any.whl (50.1 kB)

Uploaded Python 3

File details

Details for the file diffcb-0.1.7.tar.gz.

File metadata

  • Download URL: diffcb-0.1.7.tar.gz
  • Upload date:
  • Size: 64.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.15

File hashes

Hashes for diffcb-0.1.7.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | fd7925a1d4321213625acc3725cbe46e8c26749ad55e6e542f0921d4ce97154f |
| MD5 | aed2ff0ee43432b79d260cbf5283b122 |
| BLAKE2b-256 | 9bb1a2b866a4af09ba827c77d09a123ccce949a4d611783d2ba601c31957c364 |

File details

Details for the file diffcb-0.1.7-py3-none-any.whl.

File metadata

  • Download URL: diffcb-0.1.7-py3-none-any.whl
  • Upload date:
  • Size: 50.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.15

File hashes

Hashes for diffcb-0.1.7-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 94d9815230f4b86615d811a996bb2990a9fd0a624d61ffac100f75ddde816325 |
| MD5 | 1e4316195b069e540db2d78fe32ef412 |
| BLAKE2b-256 | 225e6f270e65194ec705859438bd04eccdc243a679964e2ab28a858bf702827c |
