
A differentiable implementation of kernel density estimation in PyTorch


TorchKDE :fire:


A differentiable implementation of kernel density estimation in PyTorch by Klaus-Rudolf Kladny.

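$$\hat{f}(x) = \frac{1}{|H|^{\frac{1}{2}} n} \sum_{i=1}^n K \left( H^{-\frac{1}{2}} \left( x - x_i \right) \right)$$

where $x_1, \dots, x_n$ are the data points, $K$ is the kernel, $H$ is the symmetric positive-definite bandwidth matrix and $|H|$ is its determinant.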
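To make the estimator concrete, here is a minimal pure-Python sketch (not the TorchKDE implementation) that evaluates $\hat{f}(x)$ with the standard Gaussian kernel $K(u) = (2\pi)^{-p/2} \exp(-\lVert u \rVert_2^2 / 2)$. It assumes an isotropic bandwidth matrix $H = h^2 I$, under which $|H|^{1/2} = h^p$ and $H^{-1/2}(x - x_i) = (x - x_i)/h$. The helper names `gaussian_kernel` and `kde_density` are hypothetical.

```python
import math
from typing import Sequence


def gaussian_kernel(u: Sequence[float]) -> float:
    """Standard p-dimensional Gaussian density K(u) = (2 pi)^(-p/2) exp(-||u||^2 / 2)."""
    p = len(u)
    sq_norm = sum(ui * ui for ui in u)
    return (2 * math.pi) ** (-p / 2) * math.exp(-0.5 * sq_norm)


def kde_density(x: Sequence[float], data: Sequence[Sequence[float]], h: float) -> float:
    """Evaluate f_hat(x) for the (assumed) isotropic bandwidth matrix H = h^2 I."""
    n, p = len(data), len(x)
    det_sqrt = h ** p  # |H|^(1/2) = (h^(2p))^(1/2) = h^p
    total = 0.0
    for xi in data:
        # H^(-1/2) (x - x_i) reduces to (x - x_i) / h when H = h^2 I
        u = [(xj - xij) / h for xj, xij in zip(x, xi)]
        total += gaussian_kernel(u)
    return total / (det_sqrt * n)


# A single 1-D data point at 0 with h = 1 gives the standard normal density at x = 0.
print(kde_density([0.0], [[0.0]], h=1.0))  # 1 / sqrt(2 pi) ≈ 0.3989
```

Looping over the data points mirrors the sum in the formula term by term; a tensor implementation would vectorise this loop instead.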

Installation Instructions

Clone the repository, cd into the root directory and run

pip install .

Now you are ready to go! If you would also like to run the Jupyter notebooks or contribute to this package, please also install the packages listed in requirements.txt (preferably inside a virtual environment):

pip install -r requirements.txt

What's included?

Kernel Density Estimation

The KernelDensity class supports the same operations as the KernelDensity class in scikit-learn, but is implemented in PyTorch and is differentiable with respect to the input data. Here is a little taste:

from torchkde import KernelDensity
import torch

multivariate_normal = torch.distributions.MultivariateNormal(torch.ones(2), torch.eye(2))
X = multivariate_normal.sample((1000,)) # create data
X.requires_grad = True # enable differentiation
kde = KernelDensity(bandwidth=1.0, kernel='gaussian') # create kde object with isotropic bandwidth matrix
_ = kde.fit(X) # fit kde to data

X_new = multivariate_normal.sample((100,)) # create new data 
logprob = kde.score_samples(X_new)

assert logprob.grad_fn is not None  # logprob is part of the autograd graph, so gradients can flow back to X
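TorchKDE obtains these gradients through PyTorch's autograd. To illustrate what is being differentiated, the following hypothetical pure-Python sketch (not TorchKDE code) computes $\log \hat{f}(x)$ for an isotropic Gaussian KDE with bandwidth $h$, assumed here to mean $H = h^2 I$, together with its gradient with respect to each data point, derived by hand as $\partial \log \hat{f}(x) / \partial x_i = w_i (x - x_i) / h^2$, where $w_i = K_i / \sum_j K_j$ are the normalised kernel values. A central finite difference checks one entry.

```python
import math
from typing import List, Sequence

Point = Sequence[float]


def _exponents(x: Point, data: Sequence[Point], h: float) -> List[float]:
    """Gaussian-kernel exponents -||x - x_i||^2 / (2 h^2), one per data point."""
    return [-sum((a - b) ** 2 for a, b in zip(x, xi)) / (2 * h * h) for xi in data]


def log_density(x: Point, data: Sequence[Point], h: float) -> float:
    """log f_hat(x) for an isotropic Gaussian KDE with H = h^2 I (stable log-sum-exp)."""
    d, n = len(x), len(data)
    exps = _exponents(x, data, h)
    m = max(exps)
    lse = m + math.log(sum(math.exp(e - m) for e in exps))
    return lse - math.log(n) - d * math.log(h) - 0.5 * d * math.log(2 * math.pi)


def grad_wrt_data(x: Point, data: Sequence[Point], h: float) -> List[List[float]]:
    """d log f_hat(x) / d x_i = w_i (x - x_i) / h^2, with w_i the normalised kernel values."""
    exps = _exponents(x, data, h)
    m = max(exps)
    unnorm = [math.exp(e - m) for e in exps]
    total = sum(unnorm)
    return [[(u / total) * (a - b) / (h * h) for a, b in zip(x, xi)]
            for u, xi in zip(unnorm, data)]


def shifted(data: Sequence[Point], i: int, k: int, delta: float) -> List[List[float]]:
    """Copy of data with coordinate k of point i moved by delta."""
    out = [list(row) for row in data]
    out[i][k] += delta
    return out


data = [[0.0, 0.0], [1.0, 0.5], [-0.5, 1.5]]
x, h, eps = [0.3, 0.4], 0.8, 1e-6
analytic = grad_wrt_data(x, data, h)[1][0]
numeric = (log_density(x, shifted(data, 1, 0, eps), h)
           - log_density(x, shifted(data, 1, 0, -eps), h)) / (2 * eps)
print(analytic, numeric)  # hand-derived gradient vs. finite-difference estimate
```

In a PyTorch implementation, autograd produces the same quantity without the hand derivation, which is what makes kernels and bandwidths interchangeable without rewriting any gradient code.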

You may also check out demo_kde.ipynb for a simple demo on the Bart Simpson distribution, which yields the following density estimate:

Tophat Kernel Approximation

The Tophat kernel is not differentiable at the boundary of its support (two points in one dimension) and has zero derivative everywhere else, so it provides no useful gradient signal. Thus, we provide a differentiable approximation via a generalized Gaussian (see e.g. Pascal et al. for reference):

$$K^{\text{tophat}}(x; \beta) = \frac{\beta \, \Gamma \left( \frac{p}{2} \right)}{\pi^{\frac{p}{2}} \, \Gamma \left( \frac{p}{2\beta} \right) 2^{\frac{p}{2\beta}}} \exp \left( - \frac{\lVert x \rVert_2^{2\beta}}{2} \right),$$

where $p$ is the dimensionality of $x$, $\beta > 0$ is a shape parameter and $\Gamma$ denotes the gamma function. For large values of $\beta$, this kernel approximates the Tophat kernel (the uniform density on the unit ball), as shown in the following 1-dimensional example:

We note that for $\beta = 1$, this approximation reduces to the Gaussian kernel. While the approximation improves as $\beta$ grows, its gradients with respect to the input also become steeper near the boundary $\lVert x \rVert_2 = 1$ (and close to zero elsewhere). This is a trade-off that must be balanced when choosing $\beta$.
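These properties can be checked numerically. The standard-library sketch below (function names are hypothetical, not part of TorchKDE) evaluates the kernel as a function of $r = \lVert x \rVert_2$. It confirms that $\beta = 1$ gives the standard Gaussian density, that for large $\beta$ the kernel is close to $1/V_p$ inside the unit ball (with $V_p = \pi^{p/2} / \Gamma(p/2 + 1)$ the ball's volume) and essentially zero outside it, and that the radial derivative $\partial K / \partial r = -c_\beta \, \beta \, r^{2\beta - 1} \exp(-r^{2\beta}/2)$, with $c_\beta$ the normalising constant, grows in magnitude at $r = 1$ while vanishing inside the support.

```python
import math


def norm_const(p: int, beta: float) -> float:
    """c_beta = beta * Gamma(p/2) / (pi^(p/2) * Gamma(p/(2 beta)) * 2^(p/(2 beta)))."""
    return (beta * math.gamma(p / 2)) / (
        math.pi ** (p / 2) * math.gamma(p / (2 * beta)) * 2 ** (p / (2 * beta))
    )


def tophat_approx(r: float, p: int, beta: float) -> float:
    """Generalized Gaussian kernel evaluated at r = ||x||_2 >= 0 in p dimensions."""
    return norm_const(p, beta) * math.exp(-(r ** (2 * beta)) / 2)


def tophat_approx_dr(r: float, p: int, beta: float) -> float:
    """Radial derivative dK/dr, valid for r > 0."""
    return -norm_const(p, beta) * beta * r ** (2 * beta - 1) * math.exp(-(r ** (2 * beta)) / 2)


def unit_ball_volume(p: int) -> float:
    """V_p = pi^(p/2) / Gamma(p/2 + 1)."""
    return math.pi ** (p / 2) / math.gamma(p / 2 + 1)


# beta = 1: identical to the standard Gaussian density (2 pi)^(-p/2) exp(-r^2 / 2).
for p in (1, 2, 3):
    gaussian = (2 * math.pi) ** (-p / 2) * math.exp(-(0.7 ** 2) / 2)
    print(f"p={p}: {tophat_approx(0.7, p, 1.0):.6f} vs Gaussian {gaussian:.6f}")

# Large beta: close to 1/V_p inside the unit ball, essentially 0 outside it.
for p in (1, 2):
    print(f"p={p}: K(0.5)={tophat_approx(0.5, p, 50.0):.4f}, "
          f"1/V_p={1 / unit_ball_volume(p):.4f}, K(1.5)={tophat_approx(1.5, p, 50.0):.1e}")

# Trade-off: the slope at the boundary r = 1 grows with beta, the slope at r = 0.5 vanishes.
for beta in (1.0, 10.0, 100.0):
    print(f"beta={beta:g}: dK/dr(1.0)={tophat_approx_dr(1.0, 1, beta):.3f}, "
          f"dK/dr(0.5)={tophat_approx_dr(0.5, 1, beta):.1e}")
```

Since $c_\beta \to 1/V_p$ as $\beta \to \infty$, the boundary slope behaves roughly like $\beta e^{-1/2} / V_p$, i.e. it grows about linearly in $\beta$.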

Supported Settings

The current implementation provides the following functionality:

| Feature | Supported Values |
| --- | --- |
| Kernels | Gaussian, Epanechnikov, Exponential, Tophat Approximation |
| Tree Algorithms | Standard |
| Bandwidths | Float (isotropic bandwidth matrix), Scott, Silverman |
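For reference, scikit-learn's KernelDensity defines the two rule-of-thumb bandwidths for $n$ samples in $d$ dimensions as $h_{\text{Scott}} = n^{-1/(d+4)}$ and $h_{\text{Silverman}} = \left( \frac{n(d+2)}{4} \right)^{-1/(d+4)}$. Whether TorchKDE uses exactly these definitions (for instance, without rescaling by the spread of the data) is an assumption here; `rule_of_thumb_bandwidth` is a hypothetical helper, not part of the package's API.

```python
def rule_of_thumb_bandwidth(n: int, d: int, rule: str) -> float:
    """Bandwidth for n samples in d dimensions, following scikit-learn's definitions."""
    if n < 1 or d < 1:
        raise ValueError("n and d must be positive")
    exponent = -1.0 / (d + 4)
    if rule == "scott":
        return n ** exponent
    if rule == "silverman":
        return (n * (d + 2) / 4.0) ** exponent
    raise ValueError(f"unknown rule: {rule!r}")


# The 1000 two-dimensional samples from the KernelDensity example:
print(rule_of_thumb_bandwidth(1000, 2, "scott"))      # 1000^(-1/6) ≈ 0.3162
print(rule_of_thumb_bandwidth(1000, 2, "silverman"))  # identical for d = 2, since (d + 2) / 4 = 1
```

Both rules shrink the bandwidth as $n$ grows, and more slowly in higher dimensions because of the $d + 4$ in the exponent.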

Got an Extension? Create a Pull Request!

In case you do not know how to do that, here are the necessary steps:

  1. Fork the repo
  2. Create your feature branch (git checkout -b cool_tree_algorithm)
  3. Run the unit tests (python -m tests.test_kde) and only proceed if the script outputs "OK".
  4. Commit your changes (git add -A && git commit -m 'Add cool tree algorithm')
  5. Push to the branch (git push origin cool_tree_algorithm)
  6. Open a Pull Request

Issues?

If you discover a bug or do not understand something, please create an issue or let me know directly at kkladny [at] tuebingen [dot] mpg [dot] de! I am also happy to take requests for implementing specific functionalities.

"In God we trust. All others must bring data."

— W. Edwards Deming


Download files


Source Distribution

torch_kde-0.1.0.tar.gz (9.2 kB)

Uploaded Source

Built Distribution


torch_kde-0.1.0-py3-none-any.whl (9.4 kB)

Uploaded Python 3

File details

Details for the file torch_kde-0.1.0.tar.gz.

File metadata

  • Download URL: torch_kde-0.1.0.tar.gz
  • Upload date:
  • Size: 9.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.11

File hashes

Hashes for torch_kde-0.1.0.tar.gz
Algorithm Hash digest
SHA256 ce66a039beec9a793fbb455ca090d0a5345779615632a0a7cfd4ad048164150b
MD5 e1343b64134ee79fb0161df56bf47f99
BLAKE2b-256 7914ebe5e5a24ef743263cd4703f7fb9fa6932eaeaa18fb03a6f214f03148faa


File details

Details for the file torch_kde-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: torch_kde-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 9.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.11

File hashes

Hashes for torch_kde-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 23dec8f68479d5f4c2273d16ef7a70c18ea403eeb9b1ea759870aff40b2627e7
MD5 0548b9f825ba342858e8dfde7cd49400
BLAKE2b-256 2ff503835440ff29157e8d3a2bcc803d3a7a4014a8c3b0599910d09c97762ecd

