torch-kde

A differentiable implementation of kernel density estimation in PyTorch.

https://github.com/rudolfwilliam/torch-kde

Science Score: 67.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 2 DOI reference(s) in README
  • Academic publication links
    Links to: arxiv.org, zenodo.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (15.1%) to scientific vocabulary

Keywords

kernel-density-estimation python pytorch scikit-learn
Last synced: 6 months ago

Repository

A differentiable implementation of kernel density estimation in PyTorch.

Basic Info
  • Host: GitHub
  • Owner: rudolfwilliam
  • License: MIT
  • Language: Jupyter Notebook
  • Default Branch: master
  • Homepage:
  • Size: 635 KB
Statistics
  • Stars: 10
  • Watchers: 2
  • Forks: 3
  • Open Issues: 1
  • Releases: 6
Topics
kernel-density-estimation python pytorch scikit-learn
Created about 1 year ago · Last pushed 8 months ago
Metadata Files
Readme License Citation

README.md

TorchKDE :fire:

A differentiable implementation of kernel density estimation in PyTorch.

$$\hat{f}(x) = \frac{1}{|H|^{\frac{1}{2}} n} \sum_{i=1}^n K \left( H^{-\frac{1}{2}} \left( x - x_i \right) \right)$$
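
A minimal sketch of this estimator, assuming a Gaussian kernel $K$ and an isotropic bandwidth matrix $H = h^2 I$ (illustrative only, not torch-kde's internal code):

```python
import math
import torch

def kde_eval(x: torch.Tensor, data: torch.Tensor, h: float) -> torch.Tensor:
    # x: (m, p) query points; data: (n, p) samples; returns (m,) density values
    n, p = data.shape
    diff = (x.unsqueeze(1) - data.unsqueeze(0)) / h   # H^{-1/2} (x - x_i)
    k = torch.exp(-0.5 * (diff ** 2).sum(-1)) / (2 * math.pi) ** (p / 2)
    return k.sum(dim=1) / (n * h ** p)                # |H|^{1/2} = h^p
```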

Installation Instructions

The torch-kde package can be installed via pip. Run

```bash
pip install torch-kde
```

Now you are ready to go. If you would also like to run the code from the Jupyter notebooks or contribute to this package, please also install the packages in requirements.txt:

```bash
pip install -r requirements.txt
```

What is included?

Kernel Density Estimation

The KernelDensity class supports the same operations as the KernelDensity class in scikit-learn, but it is implemented in PyTorch and differentiable with respect to the input data. Here is a little taste:

```python
from torchkde import KernelDensity
import torch

multivariate_normal = torch.distributions.MultivariateNormal(torch.ones(2), torch.eye(2))
X = multivariate_normal.sample((1000,))  # create data
X.requires_grad = True                   # enable differentiation
kde = KernelDensity(bandwidth=1.0, kernel='gaussian')  # kde object with isotropic bandwidth matrix
_ = kde.fit(X)                           # fit kde to data

X_new = multivariate_normal.sample((100,))  # create new data
logprob = kde.score_samples(X_new)

logprob.grad_fn  # is not None
```
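
Because everything is built from PyTorch operations, the log-density can be backpropagated. A minimal follow-up sketch, continuing the snippet above (the `requires_grad_` call and `backward` pass are our illustration, assuming score_samples mirrors its scikit-learn counterpart):

```python
X_new.requires_grad_(True)          # make the query points differentiable
logprob = kde.score_samples(X_new)  # log-density at each query point
logprob.sum().backward()            # standard reverse-mode pass
print(X_new.grad.shape)             # torch.Size([100, 2]): one gradient per query point
```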

You may also check out demo_kde.ipynb for a simple demo on the Bart Simpson distribution, including a plot of the resulting density estimate.

Tophat Kernel Approximation

The Tophat kernel is not differentiable at two points and has zero derivative everywhere else. Thus, we provide a differentiable approximation via a generalized Gaussian (see e.g. Pascal et al. for reference):

$$K^{\text{tophat}}(x; \beta) = \frac{\beta \, \Gamma \left( \frac{p}{2} \right)}{\pi^{\frac{p}{2}} \, \Gamma \left( \frac{p}{2\beta} \right) 2^{\frac{p}{2\beta}}} \exp \left( - \frac{\lVert x \rVert_2^{2\beta}}{2} \right),$$

where $p$ is the dimensionality of $x$. Based on this kernel, we can approximate the Tophat kernel for large values of $\beta$ (the repository includes a 1-dimensional example plot).

We note that for $\beta = 1$, this approximation corresponds to a Gaussian kernel. Also, while the approximation becomes better for large values of $\beta$, its gradients with respect to the input also become larger. This is a tradeoff that must be balanced when using this kernel.
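
For concreteness, the kernel above can be written directly in PyTorch. This is a minimal sketch of the formula (using math.lgamma for a stable normalizing constant), not torch-kde's internal implementation:

```python
import math
import torch

def tophat_approx(x: torch.Tensor, beta: float) -> torch.Tensor:
    # Generalized-Gaussian approximation to the Tophat kernel, as defined above.
    # x: (n, p) batch of points; beta = 1 recovers the Gaussian kernel.
    p = x.shape[-1]
    # log of beta * Gamma(p/2) / (pi^{p/2} * Gamma(p/(2*beta)) * 2^{p/(2*beta)})
    log_const = (
        math.log(beta)
        + math.lgamma(p / 2)
        - (p / 2) * math.log(math.pi)
        - math.lgamma(p / (2 * beta))
        - (p / (2 * beta)) * math.log(2.0)
    )
    sq_norm = (x ** 2).sum(dim=-1)                     # ||x||_2^2
    return torch.exp(log_const - sq_norm.pow(beta) / 2)
```

As a quick sanity check, for $\beta = 1$ the constant reduces to the standard Gaussian normalizer $(2\pi)^{-p/2}$.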

Supported Settings

The current implementation provides the following functionality:

| Feature         | Supported Values |
|-----------------|------------------|
| Kernels         | Gaussian, Epanechnikov, Exponential, Tophat Approximation, von Mises-Fisher (data must lie on the unit sphere) |
| Tree Algorithms | Standard |
| Bandwidths      | Float (isotropic bandwidth matrix), Scott, Silverman |
| Devices         | CPU, GPU |
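
For illustration, the table's settings might be exercised as follows. This is a hedged sketch that assumes the string bandwidths and device handling follow scikit-learn conventions, given the advertised API parity:

```python
import torch
from torchkde import KernelDensity

X = torch.randn(500, 3)

# String bandwidth rules from the table above (assumed to use the
# scikit-learn names 'scott' / 'silverman').
kde = KernelDensity(bandwidth="silverman", kernel="epanechnikov").fit(X)

# GPU support per the table: fit on CUDA tensors if a GPU is available
# (assumption: the estimator follows the device of its inputs).
if torch.cuda.is_available():
    kde_gpu = KernelDensity(bandwidth=1.0, kernel="gaussian").fit(X.cuda())
```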

Would You Like to Contribute? Create a Pull Request!

In case you do not know how to do that, here are the necessary steps:

  1. Fork the repo
  2. Create your feature branch (`git checkout -b cool_tree_algorithm`)
  3. Run the unit tests (`python -m tests.test_kde`) and only proceed if the script outputs "OK".
  4. Commit your changes (`git commit -am 'Add cool tree algorithm'`)
  5. Push to the branch (`git push origin cool_tree_algorithm`)
  6. Open a Pull Request

Issues?

If you discover a bug or do not understand something, please create an issue or let me know directly at kkladny [at] tuebingen [dot] mpg [dot] de! I am also happy to take requests for implementing specific functionalities.

> "In God we trust. All others must bring data." > > — W. Edwards Deming >

Owner

  • Name: Klaus-Rudolf Kladny
  • Login: rudolfwilliam
  • Kind: user
  • Location: Tübingen
  • Company: Max Planck Institute for Intelligent Systems

Data Science student at ETH Zürich. Machine Learning 🤖 enthusiast. Currently at the Max Planck Institute for Intelligent Systems.

Citation (CITATION.cff)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
  - family-names: Kladny
    given-names: Klaus-Rudolf
title: "TorchKDE"
version: v0.1.0
identifiers:
  - type: doi
    value: 10.5281/zenodo.14674657
date-released: 2025-01-16
url: "https://github.com/rudolfwilliam/torch-kde"

GitHub Events

Total
  • Create event: 8
  • Issues event: 15
  • Release event: 7
  • Watch event: 14
  • Issue comment event: 8
  • Public event: 1
  • Push event: 54
  • Pull request event: 2
  • Fork event: 3
Last Year
  • Create event: 8
  • Issues event: 15
  • Release event: 7
  • Watch event: 14
  • Issue comment event: 8
  • Public event: 1
  • Push event: 54
  • Pull request event: 2
  • Fork event: 3

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 7
  • Total pull requests: 1
  • Average time to close issues: 5 days
  • Average time to close pull requests: N/A
  • Total issue authors: 4
  • Total pull request authors: 1
  • Average comments per issue: 0.29
  • Average comments per pull request: 0.0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 7
  • Pull requests: 1
  • Average time to close issues: 5 days
  • Average time to close pull requests: N/A
  • Issue authors: 4
  • Pull request authors: 1
  • Average comments per issue: 0.29
  • Average comments per pull request: 0.0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • rudolfwilliam (4)
  • siemdejong (1)
  • braindevices (1)
  • marcosspsljr (1)
Pull Request Authors
  • tuchris (1)
Top Labels
Issue Labels
Pull Request Labels

Packages

  • Total packages: 1
  • Total downloads:
    • pypi 141 last-month
  • Total dependent packages: 0
  • Total dependent repositories: 0
  • Total versions: 5
  • Total maintainers: 1
pypi.org: torch-kde

A differentiable implementation of kernel density estimation in PyTorch

  • Versions: 5
  • Dependent Packages: 0
  • Dependent Repositories: 0
  • Downloads: 141 Last month
Rankings
Dependent packages count: 9.6%
Average: 32.0%
Dependent repos count: 54.3%
Maintainers (1)
Last synced: 6 months ago

Dependencies

pyproject.toml pypi
  • numpy >=1.21.0
  • requests >=2.26.0
.github/workflows/ci.yml actions
  • actions/checkout v3 composite
  • actions/setup-python v4 composite
requirements.txt pypi
  • Jinja2 ==3.1.4
  • MarkupSafe ==3.0.2
  • Pygments ==2.18.0
  • SciencePlots ==2.1.1
  • asttokens ==3.0.0
  • comm ==0.2.2
  • contourpy ==1.3.1
  • cycler ==0.12.1
  • debugpy ==1.8.9
  • decorator ==5.1.1
  • executing ==2.1.0
  • filelock ==3.16.1
  • fonttools ==4.55.2
  • fsspec ==2024.10.0
  • ipykernel ==6.29.5
  • ipython ==8.30.0
  • jedi ==0.19.2
  • jupyter_client ==8.6.3
  • jupyter_core ==5.7.2
  • kiwisolver ==1.4.7
  • matplotlib ==3.9.3
  • matplotlib-inline ==0.1.7
  • mpmath ==1.3.0
  • nest-asyncio ==1.6.0
  • networkx ==3.4.2
  • numpy ==2.2.0
  • packaging ==24.2
  • parso ==0.8.4
  • pexpect ==4.9.0
  • pillow ==11.0.0
  • platformdirs ==4.3.6
  • prompt_toolkit ==3.0.48
  • psutil ==6.1.0
  • ptyprocess ==0.7.0
  • pure_eval ==0.2.3
  • pyparsing ==3.2.0
  • python-dateutil ==2.9.0.post0
  • pyzmq ==26.2.0
  • scipy ==1.14.1
  • six ==1.17.0
  • stack-data ==0.6.3
  • sympy ==1.13.1
  • torch ==2.5.1
  • tornado ==6.4.2
  • traitlets ==5.14.3
  • triton ==3.1.0
  • typing_extensions ==4.12.2
  • wcwidth ==0.2.13