A differentiable implementation of kernel density estimation in PyTorch
TorchKDE :fire:
A differentiable implementation of kernel density estimation in PyTorch by Klaus-Rudolf Kladny.
$$\hat{f}(x) = \frac{1}{|H|^{\frac{1}{2}} n} \sum_{i=1}^n K \left( H^{-\frac{1}{2}} \left( x - x_i \right) \right)$$
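The estimator above can be sketched in a few lines for the one-dimensional case, where an isotropic bandwidth matrix reduces to $H = h^2$ and the prefactor $|H|^{-\frac{1}{2}}$ becomes $1/h$. The snippet below is a dependency-free illustration with a Gaussian kernel, not TorchKDE's actual implementation; the function names `gaussian_kernel` and `kde_density` are made up for this example.

```python
import math


def gaussian_kernel(u: float) -> float:
    """Standard normal density, used here as the kernel K."""
    return math.exp(-0.5 * u * u) / math.sqrt(2.0 * math.pi)


def kde_density(x: float, data: list[float], h: float) -> float:
    """Evaluate the 1-D estimate (1 / (n * h)) * sum_i K((x - x_i) / h)."""
    if not data:
        raise ValueError("data must contain at least one point")
    if h <= 0:
        raise ValueError("bandwidth h must be positive")
    n = len(data)
    return sum(gaussian_kernel((x - x_i) / h) for x_i in data) / (n * h)


# Hypothetical toy data, only to show the call.
data = [-1.0, 0.0, 0.5, 2.0]
print(kde_density(0.0, data, h=1.0))
```

With a single data point, the estimate is just one kernel bump centered on that point, and shrinking $h$ makes the bump narrower and taller.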
Installation Instructions
Clone the repository, cd into the root directory and run
pip install .
Now you are ready to go! If you also want to run the Jupyter notebooks or contribute to this package, install the packages listed in requirements.txt as well (ideally inside a virtual environment):
pip install -r requirements.txt
What's included?
Kernel Density Estimation
The KernelDensity class supports the same operations as the KernelDensity class in scikit-learn, but is implemented in PyTorch and is differentiable with respect to the input data. Here is a little taste:
from torchkde import KernelDensity
import torch
multivariate_normal = torch.distributions.MultivariateNormal(torch.ones(2), torch.eye(2))
X = multivariate_normal.sample((1000,)) # create data
X.requires_grad = True # enable differentiation
kde = KernelDensity(bandwidth=1.0, kernel='gaussian') # create kde object with isotropic bandwidth matrix
_ = kde.fit(X) # fit kde to data
X_new = multivariate_normal.sample((100,)) # create new data
logprob = kde.score_samples(X_new)
logprob.grad_fn # is not None
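To make "differentiable with respect to input data" concrete, here is the gradient worked out by hand for one special case. It assumes a Gaussian kernel and a fixed isotropic bandwidth $H = h^2 I$ that does not depend on the data (so not the Scott or Silverman rules); the shorthand $u_i$ is introduced only for this derivation. With $u_i = (x - x_i)/h$ and $K(u) = (2\pi)^{-\frac{p}{2}} \exp\left(-\lVert u \rVert_2^2 / 2\right)$, where $p$ is the dimensionality of $x$, differentiating the log-density at a query point $x$ with respect to a training point $x_j$ gives

$$\frac{\partial \log \hat{f}(x)}{\partial x_j} = \frac{K(u_j)}{\sum_{i=1}^n K(u_i)} \cdot \frac{x - x_j}{h^2}$$

Each training point is pulled toward the query point with a weight equal to its share of the total kernel mass. Under the stated assumptions, this is the kind of quantity that autograd would accumulate for the training data when backpropagating through the log-probabilities.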
You may also check out demo_kde.ipynb for a simple demo on the Bart Simpson distribution, which yields the following density estimate:
Tophat Kernel Approximation
In one dimension, the Tophat kernel is not differentiable at the two edges of its support and has zero derivative everywhere else, so it carries no useful gradient information. Thus, we provide a differentiable approximation via a generalized Gaussian (see e.g. Pascal et al. for reference):
$$K^{\text{tophat}}(x; \beta) = \frac{\beta \Gamma \left( \frac{p}{2} \right) }{\pi^{\frac{p}{2}} \Gamma \left( \frac{p}{2\beta} \right) 2^{\frac{p}{2\beta}}} \exp \left( - \frac{\lVert x \rVert_2^{2\beta}}{2} \right),$$
where $p$ is the dimensionality of $x$. Based on this kernel, we can approximate the Tophat kernel for large values of $\beta$, as shown in the following 1-dimensional example:
We note that for $\beta = 1$, this approximation corresponds to a Gaussian kernel. Also, while the approximation becomes better for large values of $\beta$, its gradients with respect to the input also become larger. This is a tradeoff that must be balanced when using this kernel.
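The behavior described above can be checked numerically. The following sketch evaluates the generalized Gaussian kernel as a function of the norm $r = \lVert x \rVert_2$ and, for the 1-dimensional case, its slope with respect to $x$. It is an illustration written against the formula only, not TorchKDE's code; the names `generalized_gaussian_kernel` and `kernel_slope_1d` are hypothetical.

```python
import math


def generalized_gaussian_kernel(r: float, beta: float, p: int = 1) -> float:
    """Kernel value for a point with Euclidean norm r in p dimensions."""
    norm_const = (beta * math.gamma(p / 2)) / (
        math.pi ** (p / 2) * math.gamma(p / (2 * beta)) * 2 ** (p / (2 * beta))
    )
    # Note: r ** (2 * beta) can overflow for very large r and beta.
    return norm_const * math.exp(-0.5 * r ** (2 * beta))


def kernel_slope_1d(x: float, beta: float) -> float:
    """Derivative of the 1-D kernel with respect to x, valid for x > 0."""
    return -beta * x ** (2 * beta - 1) * generalized_gaussian_kernel(x, beta)


for beta in (1.0, 5.0, 50.0):
    inside = generalized_gaussian_kernel(0.5, beta)   # inside the unit interval
    outside = generalized_gaussian_kernel(1.5, beta)  # outside the unit interval
    slope = kernel_slope_1d(1.0, beta)                # slope at the edge x = 1
    print(f"beta={beta}: K(0.5)={inside:.4f}, K(1.5)={outside:.4g}, slope(1)={slope:.4f}")
```

For $\beta = 1$ the value at the origin equals the standard Gaussian density. As $\beta$ grows, the value inside the unit interval approaches the 1-dimensional Tophat height of $1/2$, while the slope at the edge grows roughly in proportion to $\beta$.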
Supported Settings
The current implementation provides the following functionality:
| Feature | Supported Values |
|---|---|
| Kernels | Gaussian, Epanechnikov, Exponential, Tophat Approximation |
| Tree Algorithms | Standard |
| Bandwidths | Float (Isotropic bandwidth matrix), Scott, Silverman |
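For readers unfamiliar with the named bandwidth options, the sketch below computes the Scott and Silverman rule-of-thumb factors as they are commonly defined (for example in SciPy's gaussian_kde) for $n$ samples in $d$ dimensions. That TorchKDE uses exactly these expressions is an assumption, and the helper names `scott_factor` and `silverman_factor` are hypothetical.

```python
def scott_factor(n: int, d: int) -> float:
    """Scott's rule of thumb: n ** (-1 / (d + 4))."""
    return n ** (-1.0 / (d + 4))


def silverman_factor(n: int, d: int) -> float:
    """Silverman's rule of thumb: (n * (d + 2) / 4) ** (-1 / (d + 4))."""
    return (n * (d + 2) / 4.0) ** (-1.0 / (d + 4))


# Assumed example sizes; both factors shrink as the sample size grows.
for n in (100, 1000, 10000):
    print(n, scott_factor(n, d=1), silverman_factor(n, d=1))
```

Both rules shrink the bandwidth as more data becomes available, and for $d = 2$ the two expressions coincide.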
Got an Extension? Create a Pull Request!
In case you do not know how to do that, here are the necessary steps:
- Fork the repo
- Create your feature branch (git checkout -b cool_tree_algorithm)
- Run the unit tests (python -m tests.test_kde) and only proceed if the script outputs "OK".
- Commit your changes (git commit -am 'Add cool tree algorithm')
- Push to the branch (git push origin cool_tree_algorithm)
- Open a Pull Request
Issues?
If you discover a bug or do not understand something, please create an issue or let me know directly at kkladny [at] tuebingen [dot] mpg [dot] de! I am also happy to take requests for implementing specific functionalities.
"In God we trust. All others must bring data."
— W. Edwards Deming
File details
Details for the file torch_kde-0.1.0.tar.gz.
File metadata
- Download URL: torch_kde-0.1.0.tar.gz
- Upload date:
- Size: 9.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ce66a039beec9a793fbb455ca090d0a5345779615632a0a7cfd4ad048164150b |
| MD5 | e1343b64134ee79fb0161df56bf47f99 |
| BLAKE2b-256 | 7914ebe5e5a24ef743263cd4703f7fb9fa6932eaeaa18fb03a6f214f03148faa |
File details
Details for the file torch_kde-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torch_kde-0.1.0-py3-none-any.whl
- Upload date:
- Size: 9.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 23dec8f68479d5f4c2273d16ef7a70c18ea403eeb9b1ea759870aff40b2627e7 |
| MD5 | 0548b9f825ba342858e8dfde7cd49400 |
| BLAKE2b-256 | 2ff503835440ff29157e8d3a2bcc803d3a7a4014a8c3b0599910d09c97762ecd |