Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.

DCB — Differentiable Critical Bandwidth

PyPI | License: Apache 2.0 | Python 3.9+

A PyTorch package that makes Silverman's critical bandwidth test (1981) fully differentiable, enabling end-to-end gradient-based optimisation over the modal structure of continuous distributions.

Overview

h_crit is the smallest KDE bandwidth at which the density estimate of a sample has a single mode; it is the classical nonparametric statistic behind Silverman's modality test. DCB replaces every non-differentiable step with a smooth surrogate, then uses the Implicit Function Theorem (IFT) to compute exact gradients of that surrogate through the root-finding step. The backward pass needs only the converged root, not the solver iterations, so it runs at O(1) memory cost.
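For intuition, the classical (non-differentiable) statistic can be sketched in a few lines of NumPy. This is not DCB's implementation: it assumes a Gaussian kernel, counts modes on a fixed grid, and bisects on the bandwidth, relying on Silverman's result that the number of modes of a Gaussian KDE does not increase as h grows. The helper names count_modes and critical_bandwidth are hypothetical.

```python
import numpy as np

def count_modes(x, h, grid_size=512):
    """Count strict local maxima of a Gaussian KDE with bandwidth h on a regular grid."""
    grid = np.linspace(x.min() - 3 * h, x.max() + 3 * h, grid_size)
    # Unnormalised KDE: the normalising constant does not move the maxima.
    dens = np.exp(-0.5 * ((grid[:, None] - x[None, :]) / h) ** 2).sum(axis=1)
    inner = dens[1:-1]
    return int(np.sum((inner > dens[:-2]) & (inner > dens[2:])))

def critical_bandwidth(x, target_modes=1, rel_tol=1e-4):
    """Smallest h (up to rel_tol) at which the KDE has at most target_modes modes."""
    lo = 1e-3 * x.std()               # assumed small enough to show too many modes
    hi = 2.0 * (x.max() - x.min())    # wide enough that the KDE is unimodal
    while hi - lo > rel_tol * hi:
        mid = 0.5 * (lo + hi)
        if count_modes(x, mid) > target_modes:
            lo = mid                  # still too many modes: h_crit is larger
        else:
            hi = mid                  # few enough modes: h_crit is at most mid
    return hi

rng = np.random.default_rng(0)
n = 400
bimodal = np.concatenate([rng.normal(-2, 1, n // 2), rng.normal(2, 1, n // 2)])
gaussian = rng.normal(0, 1, n)
h_bimodal = critical_bandwidth(bimodal)
h_gaussian = critical_bandwidth(gaussian)
print(f"h_crit bimodal  = {h_bimodal:.3f}")
print(f"h_crit gaussian = {h_gaussian:.3f}")
```

The mode count is a step function of h, which is why this version has no useful gradient and why a smooth surrogate is needed before the IFT can be applied.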

import torch
from dcb import DCBLayer, TrainingLayer

X = torch.randn(100_000, requires_grad=True)  # 1D samples, any n from 1K to 1B
layer = DCBLayer()
h_crit = layer(X)       # differentiable scalar
h_crit.backward()       # exact IFT gradients

# For repeated training-loop use with warm-start bracket caching:
layer = TrainingLayer(warm_start=True)
for batch in dataloader:  # dataloader: any iterable of 1D sample tensors, defined elsewhere
    h = layer(batch)    # 1.8× faster after first call on CPU; ~10× on CUDA with compile=True

Installation

pip install diffcb

Or from source:

git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
cd differentiable-critical-bandwidth
pip install -e ".[dev]"

Accuracy vs R's bw.crit

Validated against R's multimode::bw.crit(data, mod0=1) (Hall & York 2001). Same-sample protocol (identical data fed to both Python and R, 3–5 seeds per n):

| n | Bimodal mean % | Bimodal max % | Gaussian mean % | Gaussian max % |
| --- | --- | --- | --- | --- |
| 1K | 0.002% | 0.005% | 0.028% | 0.053% |
| 5K | 0.003% | 0.006% | 0.107% | 0.160% |
| 10K | 0.001% | 0.005% | 0.091% | 0.256% |
| 100K | 0.004% | 0.005% | 0.092% | 0.400% |
| 1M | 0.001% | 0.001% | 0.045% | 0.066% |
| 100M¹ | 0.001% | 0.001% | 0.045% | 0.066% |
| 1B¹ | 0.001% | 0.001% | 0.045% | 0.066% |

¹ n > 1M uses the sketch API (default max_n_exact=1_000_000); accuracy is measured on the 1M sketch vs R on the same sketch. Bimodal data interleaved so every slice is 50/50.

Distribution note: Bimodal = 0.5·N(−2,1) + 0.5·N(+2,1), h_crit ≈ 1.7. Gaussian = N(0,1), h_crit ~ n^{−1/5} (shrinks with n). Independent-sample error (~0.2–0.5%) reflects sampling variability, not algorithmic error.
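The value h_crit ≈ 1.7 for the bimodal benchmark can be checked by hand in the large-n limit. Smoothing N(μ, σ²) with a Gaussian kernel of bandwidth h gives N(μ, σ² + h²), and an equal-weight mixture of two Gaussians with common standard deviation s is unimodal exactly when the means are at most 2s apart. For means ±2 and σ = 1 this gives 4 ≤ 2·sqrt(1 + h²), i.e. h ≥ sqrt(3) ≈ 1.732. The sketch below assumes a Gaussian kernel and uses hypothetical helper names; finite samples scatter around this limit.

```python
import math

def population_hcrit(mu, sigma):
    """Large-n critical bandwidth for 0.5*N(-mu, sigma^2) + 0.5*N(+mu, sigma^2)."""
    # Unimodal iff 2*mu <= 2*sqrt(sigma^2 + h^2)  <=>  h^2 >= mu^2 - sigma^2.
    return math.sqrt(max(mu * mu - sigma * sigma, 0.0))

def smoothed_mixture(x, mu, sigma, h):
    """Density of the mixture after Gaussian smoothing with bandwidth h."""
    s = math.sqrt(sigma * sigma + h * h)

    def phi(z):
        return math.exp(-0.5 * z * z) / (s * math.sqrt(2.0 * math.pi))

    return 0.5 * phi((x + mu) / s) + 0.5 * phi((x - mu) / s)

def has_dip_at_zero(mu, sigma, h, dx=1e-2):
    """True if the smoothed density has a local minimum at 0, i.e. two modes."""
    centre = smoothed_mixture(0.0, mu, sigma, h)
    return smoothed_mixture(dx, mu, sigma, h) > centre

h_star = population_hcrit(2.0, 1.0)
print(round(h_star, 3))                         # 1.732
print(has_dip_at_zero(2.0, 1.0, 0.9 * h_star))  # True: still bimodal below h_star
print(has_dip_at_zero(2.0, 1.0, 1.1 * h_star))  # False: unimodal above h_star
```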

Hardware Performance (v0.1.10, CPU Apple M-series)

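| n | t_median (ms) | n/s | fft_norm_cost¹ |
| --- | --- | --- | --- |
| 1K | 443 | 2.3K | 7,112 |
| 10K | 699 | 14K | 11,222 |
| 100K | 265 | 377K | 4,258 |
| 1M | 471 | 2.1M | 7,556 |
| 10M (full) | 1,146 | 8.7M | 18,387 |
| 10M (sketch) | 449 | 22M | 7,208 |
| 100M (sketch) | 504 | 198M | 8,094 |
| 1B (sketch) | 460 | 2.17B | 7,391 |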

¹ FFT-normalised cost = t_dcb / t_fft_one, where t_fft_one = median time for one rfft(16384, float32) on the same device (0.062 ms, Apple M). Hardware-agnostic metric for cross-chip comparison.
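To reproduce the normalisation on another machine, a local measurement of t_fft_one is needed. The sketch below is a CPU stand-in that times NumPy's rfft rather than the device FFT the footnote refers to, so its timing is only comparable to other NumPy measurements; time_rfft_ms and fft_norm_cost are hypothetical helper names.

```python
import statistics
import time

import numpy as np

def time_rfft_ms(size=16384, repeats=200):
    """Median wall-clock time of one real FFT of length `size`, in milliseconds."""
    x = np.random.default_rng(0).standard_normal(size).astype(np.float32)
    np.fft.rfft(x)  # warm-up call so one-off setup cost is not timed
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        np.fft.rfft(x)
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)

def fft_norm_cost(t_dcb_ms, t_fft_one_ms):
    """FFT-normalised cost: how many reference FFTs one DCB call is worth."""
    return t_dcb_ms / t_fft_one_ms

# Local reference FFT time (NumPy stand-in, not the device FFT).
t_fft_local = time_rfft_ms()
print(f"local t_fft_one = {t_fft_local:.4f} ms")

# Table check for n=100K: 265 ms / 0.062 ms. The small gap to the tabulated
# 4,258 is consistent with t_fft_one being rounded to 0.062 ms.
print(f"fft_norm_cost = {fft_norm_cost(265.0, 0.062):,.0f}")
```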

GPU (Kaggle P100, Round 24): 43–116× speedup vs CPU. Peak 116× at n=50K (direct-KDE parallelism).

API Reference

DCBLayer

DCBLayer(
    target_modes=1,            # target number of modes (default 1 = unimodality test)
    use_fft=True,              # FFT histogram path for n > 50K (default True)
    max_n_exact=1_000_000,     # sketch above this n; None = always exact
    sketch_size=500_000,       # target size for the random sketch subsample
    G_min=16384,               # minimum FFT histogram bins; scales with n^0.4 internally
    use_richardson=True,       # Richardson extrapolation (~30% accuracy gain on CPU)
    forward_path='smooth',     # 'smooth' (default, strictly differentiable, gradcheck passes)
                               # 'auto'   (direct-KDE at n≤25K; surrogate gradient)
                               # 'direct' (force direct-KDE; accuracy benchmarks only)
    fft_dtype=torch.float32,   # histogram FFT dtype for speed; mode counting always uses
                               # float64 internally regardless of this setting
    direct_n_max=25_000,       # direct-KDE threshold (only active in 'auto'/'direct')
    direct_M=2048,             # direct-KDE evaluation grid size
)

forward_path='smooth' (default): Both forward and backward use the same smooth M̃ surrogate at all n. torch.autograd.gradcheck passes. Gradients are exact IFT derivatives of the computed h_crit.
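The IFT step can be illustrated on a toy problem that has nothing to do with KDEs. If h(x) is defined implicitly by F(h, x) = 0, then dh/dx = −(∂F/∂x)/(∂F/∂h) evaluated at the root, so the gradient needs only the converged root and not the bisection history. The sketch below uses the hypothetical choice F(h, x) = h³ + h − x and checks the IFT gradient against a central finite difference; DCB's actual implicit function is built from its smooth surrogate M̃ and is not reproduced here.

```python
def F(h, x):
    """Toy implicit function; strictly increasing in h, so the root is unique."""
    return h ** 3 + h - x

def solve_root(x, iters=200):
    """Find h with F(h, x) = 0 by bisection (a non-differentiable black box)."""
    lo, hi = -abs(x) - 1.0, abs(x) + 1.0   # F(lo) < 0 < F(hi) for this toy F
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if F(mid, x) > 0.0:
            hi = mid
        else:
            lo = mid
    return 0.5 * (lo + hi)

def ift_grad(h, x):
    """dh/dx = -(dF/dx) / (dF/dh), evaluated at the root h."""
    dF_dx = -1.0
    dF_dh = 3.0 * h ** 2 + 1.0
    return -dF_dx / dF_dh

x0 = 2.0
h0 = solve_root(x0)          # root of h^3 + h = 2 is h = 1
g_ift = ift_grad(h0, x0)     # 1 / (3*1 + 1) = 0.25
eps = 1e-5
g_fd = (solve_root(x0 + eps) - solve_root(x0 - eps)) / (2 * eps)
print(h0, g_ift, g_fd)
```

The same check is what gradcheck automates: it only passes when the forward value and the backward formula refer to the same implicit function.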

forward_path='auto': Uses direct-KDE for the forward pass at n ≤ 25K (zero binning bias) but the IFT backward always uses M̃. Forward and backward use different implicit functions at n ≤ 25K — gradcheck fails. Use only for forward-only benchmarks.

TrainingLayer (for ML training loops)

from dcb import TrainingLayer

layer = TrainingLayer(
    warm_start=True,    # cache h_prev; narrow bracket to [0.95h, 1.05h] → 1.8× speedup
    compile=False,      # torch.compile opt-in (requires float32; on CUDA, Python ≤ 3.11 or torch ≥ 2.4)
    warm_margin=0.05,
    **dcb_kwargs,       # any DCBLayer parameter
)
layer.reset_cache()     # call when distribution shifts significantly
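The warm-start idea can be sketched independently of DCB: cache the previous root and search only a narrow bracket around it, falling back to the full bracket when the narrow one no longer contains the root. The class below is a hypothetical illustration (the names, the fallback rule, and the toy predicate are all assumptions, not TrainingLayer's code); it counts predicate evaluations to show where the saving comes from.

```python
class WarmStartBisection:
    """Bisection for the smallest h with ok(h) True, for ok monotone in h."""

    def __init__(self, full_bracket, warm_margin=0.05, rel_tol=1e-6):
        self.full_bracket = full_bracket
        self.warm_margin = warm_margin
        self.rel_tol = rel_tol
        self.h_prev = None
        self.evals = 0            # predicate evaluations in the last solve

    def reset_cache(self):
        """Forget the cached root, e.g. after a large distribution shift."""
        self.h_prev = None

    def solve(self, ok):
        self.evals = 0

        def check(h):
            self.evals += 1
            return ok(h)

        lo, hi = self.full_bracket
        if self.h_prev is not None:
            w_lo = (1.0 - self.warm_margin) * self.h_prev
            w_hi = (1.0 + self.warm_margin) * self.h_prev
            # Use the narrow bracket only if it still straddles the root.
            if not check(w_lo) and check(w_hi):
                lo, hi = w_lo, w_hi
        while hi - lo > self.rel_tol * hi:
            mid = 0.5 * (lo + hi)
            if check(mid):
                hi = mid
            else:
                lo = mid
        self.h_prev = hi
        return hi

solver = WarmStartBisection(full_bracket=(1e-3, 100.0))
h1 = solver.solve(lambda h: h >= 1.70)   # cold start over the full bracket
cold_evals = solver.evals
h2 = solver.solve(lambda h: h >= 1.72)   # root moved by ~1%: warm bracket applies
warm_evals = solver.evals
print(cold_evals, warm_evals)
```

In this toy the warm solve spends two evaluations validating the bracket and then needs fewer halvings; the actual speedup depends on how expensive each evaluation is and how far the root drifts between calls.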

Gradient quality check

import torch
from torch.autograd import gradcheck
from dcb import DCBLayer

layer = DCBLayer(forward_path='smooth')
X = torch.randn(200, dtype=torch.float64, requires_grad=True)  # gradcheck expects float64 inputs
assert gradcheck(layer, (X,), eps=1e-5, atol=1e-4)  # passes with default settings

Richardson extrapolation

By default (use_richardson=True), DCB runs a second bisection at G/2 bins and combines the two estimates as h̃ = (4·ĥ(G) − ĥ(G/2)) / 3, which reduces error by ~30% on CPU. Both passes use float64 histogram FFTs for mode counting. On GPU the second pass adds ~38% overhead for minimal gain; set use_richardson=False for GPU training loops.
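The weights 4 and −1 are the standard Richardson combination for an estimate whose leading error term scales as 1/G², since halving G multiplies that term by 4. The sketch below assumes exactly that error model with made-up constants; real binning error also has higher-order terms, so the cancellation is not exact in practice.

```python
def richardson(h_fine, h_coarse):
    """Combine estimates at G and G/2 bins, assuming error proportional to 1/G^2."""
    return (4.0 * h_fine - h_coarse) / 3.0

def toy_estimate(G, h_true=1.7, c=5.0e4, d=0.0):
    """Made-up error model: h_true + c/G^2 + d/G^3 (constants are illustrative)."""
    return h_true + c / G ** 2 + d / G ** 3

G = 16384
raw = toy_estimate(G, d=1.0e8)

# Pure second-order error: the leading term cancels completely.
pure = richardson(toy_estimate(G), toy_estimate(G // 2))

# With a cubic term added, a residual of order d/G^3 remains.
mixed = richardson(toy_estimate(G, d=1.0e8), toy_estimate(G // 2, d=1.0e8))

print(f"raw error   = {abs(raw - 1.7):.2e}")
print(f"pure error  = {abs(pure - 1.7):.2e}")
print(f"mixed error = {abs(mixed - 1.7):.2e}")
```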

Numerical Design Notes

Float64 mode counting: The bisection/trisection and Richardson loops use float64 histogram FFTs (precompute_fft(..., fft_dtype=float64)) for mode counting, regardless of fft_dtype. Float32 histogram FFT errors (~10⁻⁴ relative) create spurious sign changes in f′ near small h_crit (unimodal distributions, h_crit < ~0.5) that cannot be filtered without suppressing genuine small lobes. Float64 reduces this noise floor to ~10⁻¹⁶, which the 10⁻¹² relative threshold cleanly removes. Float32 is still used for _refine_hcrit refinement evaluations and other speed-sensitive paths.

n-aware G_min: Effective grid size = min(G_min × (n/100K)^0.4, 262144), keeping h_crit / bin_width stable as h_crit shrinks with n (Gaussian regime: h_crit ~ n^{−1/5}).
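Evaluating the stated formula for a few sample sizes shows where the cap takes over. The function below is a literal transcription with hypothetical names; whether the library rounds the result (for example to a power of two) or floors it at G_min for n < 100K is not stated here, so neither is assumed.

```python
def effective_grid_size(n, G_min=16384, n_ref=100_000, exponent=0.4, G_max=262_144):
    """min(G_min * (n / n_ref)**exponent, G_max), as written in the note above."""
    return min(G_min * (n / n_ref) ** exponent, G_max)

for n in (100_000, 1_000_000, 10_000_000, 1_000_000_000):
    print(f"n={n:>13,}  G_eff≈{effective_grid_size(n):,.0f}")
```

At n = 1B the uncapped value would be roughly 650K bins, so the 262144 cap is the binding constraint for the largest runs.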

Known Limitations

  • 1D only. No multivariate extension implemented.
  • Gaussian max error ~0.40% for ~5–10% of seeds at n=100K. Arises when the disappearing lobe height approaches the 10⁻¹² noise threshold just below h_crit, causing the bisection to converge slightly early and _refine_hcrit to select a neighbouring lobe. Bimodal data (large h_crit) is unaffected.
  • MPS float64. The float64 histogram precompute falls back to CPU on MPS devices; MPS does not support float64 FFT. Handled automatically.
  • Independent-sample variability (~0.5%). Two independent random draws from the same distribution have h_crit differing by ~0.2–0.5% at n=100K — irreducible MC variance, not algorithmic error.
  • torch.compile blockers. MPS: blocked by float64 fallback. CUDA: requires Python ≤ 3.11 or torch ≥ 2.4.

Confirmed Experimental Results

| Experiment | Result | Criterion |
| --- | --- | --- |
| Accuracy vs R: bimodal, same-sample, n=100K | 0.004% mean | < 0.01% ✓ |
| Accuracy vs R: Gaussian, same-sample, n=100K | 0.092% mean | < 0.5% ✓ |
| Accuracy vs R: bimodal, same-sample, n=1B (sketch) | 0.001% mean | < 0.01% ✓ |
| Validation (m≥2, Marron-Wand mixtures) | R²=0.91, MAE=0.07, ρ=0.89 | R²≥0.85 ✓ |
| Speedup vs scipy (CUDA T4, n=8192) | 10.5× | ≥3× ✓ |
| GAN mode preservation | h_crit=1.232 >> 0.3 | h_crit > 0.3 ✓ |
| Anomaly AUC (KDDCup99) | DCB=0.9982 vs IF=0.9867 | DCB ≥ IF ✓ |
| GPU speedup (P100, n=50K) | 116× vs CPU | |
| GPU speedup (P100, n=100K) | 43× vs CPU | |

Version History

| Version | Key change |
| --- | --- |
| 0.1.10 | n-aware G_min; float64 in Richardson path and _refine_hcrit; square-root refinement model |
| 0.1.9 | Critical fix: float64 histogram for mode counting (float32 caused 100–500% error on Gaussian/unimodal data) |
| 0.1.8 | forward_path='smooth' as default: strictly differentiable, gradcheck passes |
| 0.1.7 | Round 25: float32 cleanup, torch.compile infrastructure, forward_batched API |
| 0.1.6 | Round 24: TrainingLayer, direct-KDE path (n≤25K), compile-ready trisection |
| 0.1.4–0.1.5 | Round 22–23: 4.2× speedup, Richardson extrapolation |
| 0.1.1 | Round 20–21: MPS fix, sketch API, G_min 16384, 0.004% bimodal accuracy |

Repository Structure

dcb/
  layer.py       DCBLayer nn.Module + DCBFunction autograd
  solver.py      IFT root-finder, trisection bisection, Richardson pass
  fft_kde.py     FFT mode counter, direct_mode_count_batch, precompute_fft
  training.py    TrainingLayer with warm-start and compile support
  kde.py         Direct KDE derivatives (IFT backward path)
  utils.py       Grid, Silverman bandwidth, sg() stabiliser
tests/           Unit tests (49 passed, 1 xfailed)

Citation

@software{zhang2026diffcb,
  author = {Zhang, Ruiyu},
  title  = {diffcb: Differentiable Critical Bandwidth for PyTorch},
  year   = {2026},
  url    = {https://github.com/ryZhangHason/differentiable-critical-bandwidth},
  note   = {PyPI: diffcb v0.1.10}
}

License

Apache 2.0 — see LICENSE.
