Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.
DCB — Differentiable Critical Bandwidth
A PyTorch package that makes Silverman's critical bandwidth test (1981) fully differentiable, enabling end-to-end gradient-based optimisation over the modal structure of continuous distributions.
Overview
h_crit is the minimum KDE bandwidth at which a distribution appears unimodal — a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable step with a smooth surrogate, then uses the Implicit Function Theorem (IFT) to compute exact gradients through the root-finding step at O(1) memory cost.
import torch
from dcb import DCBLayer, TrainingLayer
X = torch.randn(10_000, requires_grad=True) # 1D samples, any n from 5K to 1B
layer = DCBLayer()
h_crit = layer(X) # differentiable scalar
h_crit.backward() # exact IFT gradients
# For repeated training-loop use with warm-start bracket caching:
layer = TrainingLayer(warm_start=True)
for batch in dataloader:
    h = layer(batch) # 1.8× faster after first call on CPU; ~10× on CUDA with compile=True
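The IFT step can be illustrated without DCB itself. The sketch below is a toy, not the package's backward pass: `g`, `solve_root` and `ift_grad` are hypothetical names, and the residual `h**3 + h - x` is chosen only because its root and derivative are known in closed form. It finds the root by bisection, then obtains the gradient from the partial derivatives at the root alone, so no intermediate bisection state has to be stored.

```python
def g(h, x):
    # Hypothetical smooth residual; its root in h depends on the parameter x.
    return h**3 + h - x

def solve_root(x, lo=-10.0, hi=10.0, iters=80):
    # Plain bisection. g is strictly increasing in h, so the root is unique.
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if g(mid, x) > 0.0:
            hi = mid
        else:
            lo = mid
    return 0.5 * (lo + hi)

def ift_grad(h_star, eps=0.0):
    # Implicit Function Theorem: dh*/dx = -(dg/dx) / (dg/dh) at the root.
    dg_dx = -1.0
    dg_dh = 3.0 * h_star**2 + 1.0
    # Optional clamp of the denominator (an assumption about what a
    # "safe backward" could look like, not the package's actual code).
    if eps > 0.0 and abs(dg_dh) < eps:
        dg_dh = eps if dg_dh >= 0.0 else -eps
    return -dg_dx / dg_dh

x = 2.0
h_star = solve_root(x)      # root of h^3 + h = 2
grad = ift_grad(h_star)     # gradient from the root alone
# Central finite difference through the solver, for comparison only.
fd = (solve_root(x + 1e-5) - solve_root(x - 1e-5)) / 2e-5
```

The finite-difference value needs two extra root solves, whereas the IFT gradient reuses the root already computed in the forward pass.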
Installation
pip install diffcb
Or from source:
git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
cd differentiable-critical-bandwidth
pip install -e ".[dev]"
Accuracy vs R's bw.crit
Validated against R's multimode::bw.crit(data, mod0=1) (Hall & York 2001).
Same-sample protocol (identical data fed to both Python and R):
| n | DCB error vs R | Notes |
|---|---|---|
| 5K–25K | < 0.005% | Direct-KDE path, zero histogram bias |
| 100K | 0.003% | FFT histogram path, G=16384 |
| 1M | 0.003% | FFT path |
| 10M | 0.003% | FFT path |
| 100M+ | < 0.01% | Histogram-dominated; sketch available |
Independent-sample error (~0.2–0.5%) reflects natural sampling variability (two RNGs), not algorithmic error. The 0.003% algorithmic error is of the same order as R's own ~0.001% numerical noise floor.
Hardware Performance (v0.1.6)
| n | CPU (Apple M) | MPS | P100 GPU |
|---|---|---|---|
| 10K | 2,300 ms | 1,400 ms | 107 ms |
| 50K | 2,900 ms | 1,700 ms | 167 ms |
| 100K | 265 ms | 248 ms | 35 ms |
| 1M | 269 ms | 189 ms | 36 ms |
| 10M | 544 ms | — | 44 ms |
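Reported P100 speedup: 43–116× vs CPU, peaking at 116× for n=50K (direct-KDE GPU parallelism). The Apple M column above is not the baseline for these figures: its ratios to the P100 column range from about 7× to 22×.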
Cumulative speedup vs v0.1.4 on CPU: 1.1× (100K), 1.7× (1M), 4.2× (10M).
API Reference
DCBLayer
DCBLayer(
    target_modes=1,        # target number of modes (default 1)
    use_fft=True,          # FFT path for n > 50K (default True)
    max_n_exact=None,      # sketch above this n (None = always exact)
    G_min=16384,           # minimum FFT histogram bins (accuracy ↑ with G)
    use_richardson=True,   # Richardson extrapolation on h_crit (30% accuracy gain)
    direct_n_max=25_000,   # use direct-KDE (no histogram) for n ≤ this
    direct_M=2048,         # direct-KDE evaluation grid size
    use_compile=False,     # infrastructure flag; use TrainingLayer for compile
    safe_backward=False,   # clamp IFT denominator near bifurcations
)
TrainingLayer (for ML training loops)
from dcb import TrainingLayer
layer = TrainingLayer(
    warm_start=True,    # cache h_prev; init bracket to [0.95h, 1.05h] → 1.8× CPU speedup
    compile=False,      # torch.compile opt-in (requires float32, Python ≤ 3.11 on CUDA)
    warm_margin=0.05,   # bracket half-width around cached h_crit
    **dcb_kwargs,       # any DCBLayer parameter
)
layer.reset_cache() # call on distribution shift
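The warm-start idea can be sketched in plain Python. `WarmStartBisect` below is a hypothetical helper, not part of the dcb API: it caches the previous root, tries the bracket `[(1 - margin)·h_prev, (1 + margin)·h_prev]` first, and falls back to the full bracket when the root has moved outside it. The fallback rule is an assumption; the package may handle that case differently.

```python
class WarmStartBisect:
    """Hypothetical bisection solver that reuses the previous root."""

    def __init__(self, lo, hi, margin=0.05, tol=1e-8):
        self.lo, self.hi, self.margin, self.tol = lo, hi, margin, tol
        self.h_prev = None   # cached root from the previous call
        self.n_evals = 0     # number of function evaluations so far

    def reset_cache(self):
        # Forget the cached root, e.g. after a distribution shift.
        self.h_prev = None

    def _eval(self, f, h):
        self.n_evals += 1
        return f(h)

    def solve(self, f):
        # f is assumed increasing with one sign change inside [lo, hi].
        lo, hi = self.lo, self.hi
        if self.h_prev is not None:
            w_lo = self.h_prev * (1.0 - self.margin)
            w_hi = self.h_prev * (1.0 + self.margin)
            # Keep the narrow bracket only if it still straddles the root.
            if self._eval(f, w_lo) < 0.0 < self._eval(f, w_hi):
                lo, hi = w_lo, w_hi
        while hi - lo > self.tol:
            mid = 0.5 * (lo + hi)
            if self._eval(f, mid) > 0.0:
                hi = mid
            else:
                lo = mid
        self.h_prev = 0.5 * (lo + hi)
        return self.h_prev

solver = WarmStartBisect(lo=0.01, hi=10.0)
h1 = solver.solve(lambda h: h - 1.00)   # cold start on the full bracket
cold = solver.n_evals
h2 = solver.solve(lambda h: h - 1.02)   # root moved 2%, inside the margin
warm = solver.n_evals - cold
h3 = solver.solve(lambda h: h - 2.00)   # root moved far: full-bracket fallback
```

In this toy the warm call spends two evaluations checking the narrow bracket and still needs fewer evaluations overall than the cold call.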
Direct-KDE path (n ≤ 25K)
For small samples, DCB evaluates f′_h directly without histogram binning (O(n·M) per evaluation, zero binning bias). On CPU this path is 3–4× slower; on GPU it runs 80–96× faster than on CPU.
# Force direct-KDE for all n (accuracy benchmark):
layer = DCBLayer(direct_n_max=float('inf'))
# Disable direct-KDE (speed benchmark):
layer = DCBLayer(direct_n_max=0)
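For intuition, direct evaluation amounts to computing the KDE derivative on a grid and counting its sign changes. The sketch below is a minimal pure-Python version under stated assumptions: a Gaussian kernel, a uniform grid, and hypothetical function names (`kde_derivative`, `count_modes`, `critical_bandwidth`) that are not the package's own. It is neither smooth nor differentiable; it only shows the classical statistic that DCB approximates.

```python
import math

def kde_derivative(t, xs, h):
    # Unnormalised derivative of a Gaussian KDE at t; only its sign is used.
    return sum(-(t - x) * math.exp(-0.5 * ((t - x) / h) ** 2) for x in xs)

def count_modes(xs, h, M=201, pad=2.0):
    # Evaluate f'_h on an M-point grid (O(n*M) work) and count the
    # positive-to-negative sign changes, each of which marks a mode.
    lo, hi = min(xs) - pad, max(xs) + pad
    modes, prev = 0, 0
    for k in range(M):
        t = lo + (hi - lo) * k / (M - 1)
        d = kde_derivative(t, xs, h)
        sign = (d > 0) - (d < 0)
        if sign == 0:
            continue  # skip exact zeros, keep the last nonzero sign
        if prev > 0 and sign < 0:
            modes += 1
        prev = sign
    return modes

def critical_bandwidth(xs, lo=0.01, hi=10.0, iters=40):
    # Bisection on h. For a Gaussian kernel the mode count does not
    # increase with h (Silverman 1981), so the threshold is well defined.
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if count_modes(xs, mid) <= 1:
            hi = mid
        else:
            lo = mid
    return hi

xs = [-2.1, -2.0, -1.9, 1.9, 2.0, 2.1]   # two tight clusters near -2 and +2
modes_small_h = count_modes(xs, 0.5)     # narrow kernel resolves both clusters
modes_large_h = count_modes(xs, 3.0)     # wide kernel merges them
h_crit = critical_bandwidth(xs)          # close to 2 for this toy sample
```

Two clusters centred at ±2 merge into one mode once the bandwidth reaches roughly the half-separation, which is why `h_crit` lands near 2 here.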
Richardson extrapolation
By default (use_richardson=True), DCB runs a second bisection at G/2 = 8192 and combines the two estimates as
h̃ = (4·ĥ(G) − ĥ(G/2)) / 3, which reduces error by about 30%. On GPU this adds 38% overhead for an
accuracy gain below 0.01%, so consider use_richardson=False for GPU training loops.
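The weights 4 and −1 in this combination are the ones that cancel an error term proportional to 1/G². The sketch below checks that on a toy estimator whose bias is exactly c/G². `toy_estimate` and its constants are hypothetical, and the real histogram bias need not follow this form exactly, which would be consistent with the reported gain being partial rather than total.

```python
def richardson(h_fine, h_coarse):
    # Combine estimates at G and G/2 so that a 1/G^2 error term cancels:
    # 4*(h + c/G^2) - (h + 4c/G^2) = 3h.
    return (4.0 * h_fine - h_coarse) / 3.0

def toy_estimate(G, h_true=0.5, c=1.0e4):
    # Hypothetical estimator with a pure 1/G^2 bias (illustration only).
    return h_true + c / G**2

fine = toy_estimate(16384)     # estimate on the full grid
coarse = toy_estimate(8192)    # estimate on the half grid
combined = richardson(fine, coarse)

err_fine = abs(fine - 0.5)
err_combined = abs(combined - 0.5)
```

The second estimate costs an extra bisection, in line with the overhead noted above.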
Known Limitations
- compile=True on MPS: blocked by float64 in _refine_hcrit fallback (fix in v0.1.7)
- compile=True on CUDA with Python 3.12: requires torch ≥ 2.4 or Python ≤ 3.11
- gradcheck with direct-KDE forward: forward (n ≤ 25K) uses exact KDE; backward uses smooth IFT surrogate; gradcheck will fail by design; use DCBLayer(direct_n_max=0) for gradcheck
- n > 100M: requires streaming histogram (not yet public API); use max_n_exact=1_000_000 sketch as workaround
Confirmed Experimental Results
| Experiment | Result | Criterion |
|---|---|---|
| Accuracy vs R (same data, n=100K) | 0.003% | < 0.01% ✓ |
| Validation (m≥2, Marron-Wand) | R²=0.91, MAE=0.07, ρ=0.89 | R²≥0.85 ✓ |
| Speedup vs scipy (CUDA T4, n=8192) | 10.5× | ≥3× ✓ |
| GAN mode preservation | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
| Anomaly AUC (KDDCup99) | DCB=0.9982 vs IF=0.9867 | DCB≥IF ✓ |
| GPU speedup (P100, n=50K) | 116× vs CPU | — |
| GPU speedup (P100, n=100K) | 43× vs CPU | — |
Changelog
v0.1.6 (2026-05-30)
- TrainingLayer: warm-start bracket caching (1.82× CPU speedup in training loops)
- direct_mode_count_batch: direct-KDE path for n ≤ 25K (zero histogram bias; 80–96× GPU speedup)
- Compile-ready trisection: tensor lo/hi, no .item() inside loop, fixed 16-round unroll
- mode_count_from_C_batch returns Tensor(B,) (was list[int]), which enables torch.compile tracing
v0.1.5 (2026-05-29)
- Richardson extrapolation on h_crit scalar (30% accuracy gain, G=16384+8192)
- alloc/sync hygiene: removed nonzero_mask host sync (4.2× faster at n=10M)
- Batched trisection bisection (one irfft dispatch per round)
- Eliminated duplicate O(n) histogram in _refine_hcrit (C_external reuse)
v0.1.4 (2026-05-29)
- FFT histogram path: C hoisted out of bisection loop (Worker 1)
- Device-native histogram: CUDA histc, MPS scatter_add_, CPU bucketize+bincount
- float32 FFT default; pad_factor 4→2 (halves irfft size)
- Adaptive bisection early-exit
v0.1.1 (2026-05-29)
- MPS histc OOM bug fixed (bucketize+bincount)
- Sketch API: max_n_exact=1M, sketch_size=500K
- Domain consistency and bias warning fixes
Repository Structure
dcb/
    layer.py      DCBLayer nn.Module + DCBFunction autograd
    solver.py     IFT root-finder, trisection bisection, Richardson pass
    fft_kde.py    FFT mode counter, direct_mode_count_batch, precompute_fft
    training.py   TrainingLayer with warm-start and compile support
    kde.py        Direct KDE derivatives (IFT backward path)
    utils.py      Grid, Silverman bandwidth, sg() stabiliser
experiments/ Reproduction scripts for all benchmarks and paper figures
tests/ Unit tests (45 passed, 1 xfailed)
License
Apache 2.0 — see LICENSE.