Skip to main content

Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.

Project description

DCB — Differentiable Critical Bandwidth

PyPI License: Apache 2.0 Python 3.9+

A PyTorch package that makes Silverman's critical bandwidth test (1981) fully differentiable, enabling end-to-end gradient-based optimisation over the modal structure of continuous distributions.

Overview

h_crit is the minimum KDE bandwidth at which a distribution appears unimodal — a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable step with a smooth surrogate, then uses the Implicit Function Theorem (IFT) to compute exact gradients through the root-finding step at O(1) memory cost.

import torch
from dcb import DCBLayer, TrainingLayer

X = torch.randn(10_000, requires_grad=True)  # 1D samples, any n from 5K to 1B
layer = DCBLayer()
h_crit = layer(X)       # differentiable scalar
h_crit.backward()       # exact IFT gradients

# For repeated training-loop use with warm-start bracket caching:
layer = TrainingLayer(warm_start=True)
for batch in dataloader:
    h = layer(batch)    # 1.8× faster after first call on CPU; ~10× on CUDA with compile=True

Installation

pip install diffcb

Or from source:

git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
cd differentiable-critical-bandwidth
pip install -e ".[dev]"

Accuracy vs R's bw.crit

Validated against R's multimode::bw.crit(data, mod0=1) (Hall & York 2001). Same-sample protocol (identical data fed to both Python and R):

n DCB error vs R Notes
5K–25K < 0.005% Direct-KDE path, zero histogram bias
100K 0.003% FFT histogram path, G=16384
1M 0.003% FFT path
10M 0.003% FFT path
100M+ < 0.01% Histogram-dominated; sketch available

Independent-sample error (~0.2–0.5%) reflects natural sampling variability (two RNGs), not algorithmic error. The 0.003% algorithmic error sits below R's own ~0.001% numerical noise floor.

Hardware Performance (v0.1.6)

n CPU (Apple M) MPS P100 GPU
10K 2,300 ms 1,400 ms 107 ms
50K 2,900 ms 1,700 ms 167 ms
100K 265 ms 248 ms 35 ms
1M 269 ms 189 ms 36 ms
10M 544 ms 44 ms

P100 speedup: 43–116× vs CPU. Peak 116× at n=50K (direct-KDE GPU parallelism).

Cumulative speedup vs v0.1.4 on CPU: 1.1× (100K), 1.7× (1M), 4.2× (10M).

API Reference

DCBLayer

DCBLayer(
    target_modes=1,           # target number of modes (default 1)
    use_fft=True,             # FFT path for n > 50K (default True)
    max_n_exact=None,         # sketch above this n (None = always exact)
    G_min=16384,              # minimum FFT histogram bins (accuracy ↑ with G)
    use_richardson="auto",    # Richardson on CPU, off on GPU (30% accuracy gain on CPU)
    direct_n_max=25_000,      # direct-KDE active only when forward_path='auto'/'direct'
    direct_M=2048,            # direct-KDE evaluation grid size
    forward_path='smooth',    # 'smooth' (default, strictly differentiable) |
                              # 'auto' (direct-KDE at n≤25K, surrogate gradient) |
                              # 'direct' (force direct-KDE, accuracy benchmarks)
    safe_backward=False,      # clamp IFT denominator near bifurcations
)

TrainingLayer (for ML training loops)

from dcb import TrainingLayer

layer = TrainingLayer(
    warm_start=True,    # cache h_prev; init bracket to [0.95h, 1.05h] → 1.8× CPU speedup
    compile=False,      # torch.compile opt-in (requires float32, Python ≤ 3.11 on CUDA)
    warm_margin=0.05,   # bracket half-width around cached h_crit
    **dcb_kwargs,       # any DCBLayer parameter
)
layer.reset_cache()     # call on distribution shift

Direct-KDE path (n ≤ 25K)

For small samples, DCB evaluates f′_h directly without histogram binning (O(n·M) per evaluation, zero binning bias). This is 3–4× slower on CPU but 80–96× faster than CPU on GPU.

# Force direct-KDE for all n (accuracy benchmark):
layer = DCBLayer(direct_n_max=float('inf'))

# Disable direct-KDE (speed benchmark):
layer = DCBLayer(direct_n_max=0)

Richardson extrapolation

By default (use_richardson=True), DCB runs a second bisection at G/2=8192 and combines: h̃ = (4·ĥ(G) − ĥ(G/2)) / 3, reducing error ~30%. On GPU this adds 38% overhead with <0.01% accuracy gain — consider use_richardson=False for GPU training loops.

Known Limitations

  • compile=True on MPS: blocked by float64 in _refine_hcrit fallback (fix in v0.1.7)
  • compile=True on CUDA with Python 3.12: requires torch ≥ 2.4 or Python ≤ 3.11
  • gradcheck: passes with the default forward_path='smooth'; the default is strictly differentiable at all n. Opt into forward_path='auto' only for forward-only accuracy benchmarks (surrogate gradient at n≤25K)
  • n > 100M: requires streaming histogram (not yet public API); use max_n_exact=1_000_000 sketch as workaround

Confirmed Experimental Results

Experiment Result Criterion
Accuracy vs R (same data, n=100K) 0.003% < 0.01% ✓
Validation (m≥2, Marron-Wand) R²=0.91, MAE=0.07, ρ=0.89 R²≥0.85 ✓
Speedup vs scipy (CUDA T4, n=8192) 10.5× ≥3× ✓
GAN mode preservation h_crit=1.232 >> 0.3 h_crit>0.3 ✓
Anomaly AUC (KDDCup99) DCB=0.9982 vs IF=0.9867 DCB≥IF ✓
GPU speedup (P100, n=50K) 116× vs CPU
GPU speedup (P100, n=100K) 43× vs CPU

Changelog

v0.1.6 (2026-05-30)

  • TrainingLayer: warm-start bracket caching (1.82× CPU speedup in training loops)
  • direct_mode_count_batch: direct-KDE path for n ≤ 25K (zero histogram bias; 80–96× GPU speedup)
  • Compile-ready trisection: tensor lo/hi, no .item() inside loop, fixed 16-round unroll
  • mode_count_from_C_batch returns Tensor(B,) (was list[int]) — enables torch.compile tracing

v0.1.5 (2026-05-29)

  • Richardson extrapolation on h_crit scalar (30% accuracy gain, G=16384+8192)
  • alloc/sync hygiene: removed nonzero_mask host sync (4.2× faster at n=10M)
  • Batched trisection bisection (one irfft dispatch per round)
  • Eliminated duplicate O(n) histogram in _refine_hcrit (C_external reuse)

v0.1.4 (2026-05-29)

  • FFT histogram path: C hoisted out of bisection loop (Worker 1)
  • Device-native histogram: CUDA histc, MPS scatter_add_, CPU bucketize+bincount
  • float32 FFT default; pad_factor 4→2 (halves irfft size)
  • Adaptive bisection early-exit

v0.1.1 (2026-05-29)

  • MPS histc OOM bug fixed (bucketize+bincount)
  • Sketch API: max_n_exact=1M, sketch_size=500K
  • Domain consistency and bias warning fixes

Repository Structure

dcb/
  layer.py         DCBLayer nn.Module + DCBFunction autograd
  solver.py        IFT root-finder, trisection bisection, Richardson pass
  fft_kde.py       FFT mode counter, direct_mode_count_batch, precompute_fft
  training.py      TrainingLayer with warm-start and compile support
  kde.py           Direct KDE derivatives (IFT backward path)
  utils.py         Grid, Silverman bandwidth, sg() stabiliser
experiments/       Reproduction scripts for all benchmarks and paper figures
tests/             Unit tests (45 passed, 1 xfailed)

License

Apache 2.0 — see LICENSE.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

diffcb-0.1.10.tar.gz (68.2 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

diffcb-0.1.10-py3-none-any.whl (51.1 kB view details)

Uploaded Python 3

File details

Details for the file diffcb-0.1.10.tar.gz.

File metadata

  • Download URL: diffcb-0.1.10.tar.gz
  • Upload date:
  • Size: 68.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.15

File hashes

Hashes for diffcb-0.1.10.tar.gz
Algorithm Hash digest
SHA256 8c556ff6aefb5fece45a002b54f9457863b4676ec58016dde9dd5250a04d09b2
MD5 19bcac492920fd96aab9f2951e8ef0f5
BLAKE2b-256 e8b78bf4bcb8194ea606bee1bd507739180f3e71852ef91c167b2af06b42f98a

See more details on using hashes here.

File details

Details for the file diffcb-0.1.10-py3-none-any.whl.

File metadata

  • Download URL: diffcb-0.1.10-py3-none-any.whl
  • Upload date:
  • Size: 51.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.15

File hashes

Hashes for diffcb-0.1.10-py3-none-any.whl
Algorithm Hash digest
SHA256 ffc2be338ac76478d53960a0ef9998bd7b7949b091002dba131e808d9c87adde
MD5 201548cbc7c940dd16f8cb383234b7ee
BLAKE2b-256 73e4abf44f3ca02700dd9b10db6ac8c86e385b47be7e9e6dda7fed7da00a4cf5

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page