Pairwise Metrics for PyTorch

TorchPairwise

This package provides highly-efficient pairwise metrics for PyTorch.

Highlights

torchpairwise is a collection of general-purpose pairwise metric functions that behave similarly to torch.cdist, which only implements the $L_p$ distance. It offers many more metrics, ported from other packages such as scipy.spatial.distance and sklearn.metrics.pairwise. For task-specific metrics (e.g. for evaluating classification, regression, or clustering), you are in the wrong place; please head to the TorchMetrics repo.

The metrics are written with torch's C++ API. The main differences from the packages above are that our metrics:

  • are all differentiable (except some boolean distances), with backward formulas manually derived, implemented, and verified with torch.autograd.gradcheck.
  • are batched and can exploit GPU parallelization.
  • can be integrated seamlessly into PyTorch-based projects; all functions are torch.jit.script-able.
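
To make the "differentiable" and "batched" properties concrete, here is a minimal sketch that uses only stock PyTorch. The function manhattan_pairwise is a hypothetical pure-PyTorch stand-in written for this illustration; it is not torchpairwise's C++ implementation. It is checked against torch.cdist with p=1 and then verified with torch.autograd.gradcheck, the same tool the list above mentions.

```python
import torch


def manhattan_pairwise(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: L1 distance between every row of x1 and x2.

    Accepts (..., n, d) and (..., m, d) and returns (..., n, m),
    so leading batch dimensions are handled by broadcasting.
    """
    diff = x1.unsqueeze(-2) - x2.unsqueeze(-3)  # (..., n, m, d)
    return diff.abs().sum(dim=-1)


torch.manual_seed(0)
# gradcheck compares analytic and numerical gradients, so it needs float64.
x1 = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
x2 = torch.rand(5, 3, dtype=torch.double, requires_grad=True)

out = manhattan_pairwise(x1, x2)
matches_cdist = torch.allclose(out, torch.cdist(x1, x2, p=1.0))
grad_ok = torch.autograd.gradcheck(manhattan_pairwise, (x1, x2))

# The same function also works on a batch of point sets.
batched = manhattan_pairwise(torch.rand(2, 4, 3), torch.rand(2, 5, 3))
```

The broadcasting version materializes an (n, m, d) intermediate, which is exactly the kind of memory cost a dedicated kernel can avoid.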

List of pairwise distance metrics

| torchpairwise ops | Equivalences in other libraries | Differentiable |
|---|---|---|
| euclidean_distances | sklearn.metrics.pairwise.euclidean_distances | ✔️ |
| haversine_distances | sklearn.metrics.pairwise.haversine_distances | ✔️ |
| manhattan_distances | sklearn.metrics.pairwise.manhattan_distances | ✔️ |
| cosine_distances | sklearn.metrics.pairwise.cosine_distances | ✔️ |
| l1_distances | (Alias of manhattan_distances) | ✔️ |
| l2_distances | (Alias of euclidean_distances) | ✔️ |
| lp_distances | (Alias of minkowski_distances) | ✔️ |
| linf_distances | (Alias of chebyshev_distances) | ✔️ |
| directed_hausdorff_distances | scipy.spatial.distance.directed_hausdorff [^1] | ✔️ |
| minkowski_distances | scipy.spatial.distance.minkowski [^1] | ✔️ |
| wminkowski_distances | scipy.spatial.distance.wminkowski [^1] | ✔️ |
| sqeuclidean_distances | scipy.spatial.distance.sqeuclidean [^1] | ✔️ |
| correlation_distances | scipy.spatial.distance.correlation [^1] | ✔️ |
| hamming_distances | scipy.spatial.distance.hamming [^1] | ❌[^2] |
| jaccard_distances | scipy.spatial.distance.jaccard [^1] | ❌[^2] |
| kulsinski_distances | scipy.spatial.distance.kulsinski [^1] | ❌[^2] |
| kulczynski1_distances | scipy.spatial.distance.kulczynski1 [^1] | ❌[^2] |
| seuclidean_distances | scipy.spatial.distance.seuclidean [^1] | ✔️ |
| cityblock_distances | scipy.spatial.distance.cityblock [^1] (Alias of manhattan_distances) | ✔️ |
| mahalanobis_distances | scipy.spatial.distance.mahalanobis [^1] | ✔️ |
| chebyshev_distances | scipy.spatial.distance.chebyshev [^1] | ✔️ |
| braycurtis_distances | scipy.spatial.distance.braycurtis [^1] | ✔️ |
| canberra_distances | scipy.spatial.distance.canberra [^1] | ✔️ |
| jensenshannon_distances | scipy.spatial.distance.jensenshannon [^1] | ✔️ |
| yule_distances | scipy.spatial.distance.yule [^1] | ❌[^2] |
| dice_distances | scipy.spatial.distance.dice [^1] | ❌[^2] |
| rogerstanimoto_distances | scipy.spatial.distance.rogerstanimoto [^1] | ❌[^2] |
| russellrao_distances | scipy.spatial.distance.russellrao [^1] | ❌[^2] |
| sokalmichener_distances | scipy.spatial.distance.sokalmichener [^1] | ❌[^2] |
| sokalsneath_distances | scipy.spatial.distance.sokalsneath [^1] | ❌[^2] |
| snr_distances | pytorch_metric_learning.distances.SNRDistance [^1] | ✔️ |

[^1]: These metrics are not pairwise, but a pairwise form can be computed by calling scipy.spatial.distance.cdist(x1, x2, metric="[metric_name_or_callable]").

[^2]: These are boolean distances. hamming_distances can also be applied to floating-point inputs, but it relies on element-wise comparison.
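
As a rough illustration of what "pairwise form" means in the first footnote, the sketch below lifts a two-vector metric to a full distance matrix where entry (i, j) is the distance between row i of x1 and row j of x2. It is plain Python written for this explanation; the helper names pairwise, chebyshev, and jaccard are hypothetical and are not the library's API. The Jaccard example also shows why boolean distances rely on comparisons and counts rather than smooth arithmetic. Returning 0 when both vectors are all-False is an assumption of this sketch.

```python
from typing import Callable, List, Sequence


def chebyshev(u: Sequence[float], v: Sequence[float]) -> float:
    """Two-vector Chebyshev (L-infinity) distance."""
    return max(abs(a - b) for a, b in zip(u, v))


def jaccard(u: Sequence[bool], v: Sequence[bool]) -> float:
    """Two-vector Jaccard dissimilarity for boolean inputs."""
    either = sum(1 for a, b in zip(u, v) if a or b)
    differ = sum(1 for a, b in zip(u, v) if a != b)
    # Assumption of this sketch: two all-False vectors have distance 0.
    return differ / either if either else 0.0


def pairwise(metric: Callable, x1: Sequence, x2: Sequence) -> List[List[float]]:
    """Apply a two-vector metric to every (row of x1, row of x2) pair."""
    return [[metric(u, v) for v in x2] for u in x1]


x1 = [[0.0, 0.0], [1.0, 3.0]]
x2 = [[1.0, 1.0], [4.0, 0.0], [0.0, 2.0]]
dist = pairwise(chebyshev, x1, x2)  # 2 x 3 matrix

b1 = [[True, False, True]]
b2 = [[True, True, False], [False, False, False]]
bool_dist = pairwise(jaccard, b1, b2)  # 1 x 2 matrix
```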

Other pairwise metrics or kernel functions

These metrics are usually used to compute kernels for machine learning algorithms.

| torchpairwise ops | Equivalences in other libraries | Differentiable |
|---|---|---|
| linear_kernel | sklearn.metrics.pairwise.linear_kernel | ✔️ |
| polynomial_kernel | sklearn.metrics.pairwise.polynomial_kernel | ✔️ |
| sigmoid_kernel | sklearn.metrics.pairwise.sigmoid_kernel | ✔️ |
| rbf_kernel | sklearn.metrics.pairwise.rbf_kernel | ✔️ |
| laplacian_kernel | sklearn.metrics.pairwise.laplacian_kernel | ✔️ |
| cosine_similarity | sklearn.metrics.pairwise.cosine_similarity | ✔️ |
| additive_chi2_kernel | sklearn.metrics.pairwise.additive_chi2_kernel | ✔️ |
| chi2_kernel | sklearn.metrics.pairwise.chi2_kernel | ✔️ |
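
To show how a kernel relates to a distance, the following plain-Python sketch computes an RBF kernel matrix from squared Euclidean distances, using the definition documented by scikit-learn: K(x, y) = exp(-gamma * ||x - y||^2), with gamma defaulting to 1 / n_features. The function below is a reference written for this explanation, not torchpairwise's implementation, and the argument names of the torchpairwise op are not assumed here.

```python
import math
from typing import List, Optional, Sequence


def sqeuclidean(u: Sequence[float], v: Sequence[float]) -> float:
    """Squared Euclidean distance between two vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def rbf_kernel(x1: Sequence[Sequence[float]],
               x2: Sequence[Sequence[float]],
               gamma: Optional[float] = None) -> List[List[float]]:
    """K[i][j] = exp(-gamma * ||x1[i] - x2[j]||^2)."""
    if gamma is None:
        gamma = 1.0 / len(x1[0])  # scikit-learn's documented default
    return [[math.exp(-gamma * sqeuclidean(u, v)) for v in x2] for u in x1]


x1 = [[0.0, 0.0], [1.0, 1.0]]
x2 = [[0.0, 0.0], [1.0, 0.0]]
K = rbf_kernel(x1, x2, gamma=0.5)
```

Identical points give a kernel value of 1, and the value decays towards 0 as the points move apart, which is why these functions are similarities rather than distances.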

Custom cdist and pdist

Furthermore, we provide a convenient wrapper function analogous to torch.cdist, except that it takes a string metric: str = "minkowski" as the third argument to select the metric, and extra metric-specific arguments are passed as keywords.

import torch, torchpairwise

# directed_hausdorff_distances is a pairwise 2d metric
x1 = torch.rand(10, 6, 3)
x2 = torch.rand(8, 5, 3)

generator = torch.Generator().manual_seed(1)
output = torchpairwise.cdist(x1, x2,
                             metric="directed_hausdorff",
                             shuffle=True,  # kwargs exclusive to directed_hausdorff
                             generator=generator)
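
The string-keyed interface can be pictured as a small dispatcher that looks up a function by name and forwards the remaining keywords to it. The sketch below is a plain-Python illustration of that pattern only; the registry, the decorator, and the list-based metric functions are hypothetical and do not reflect how torchpairwise.cdist is implemented internally.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical registry mapping metric names to pairwise functions.
_METRICS: Dict[str, Callable] = {}


def register(name: str) -> Callable:
    """Decorator that records a pairwise function under a string key."""
    def decorator(fn: Callable) -> Callable:
        _METRICS[name] = fn
        return fn
    return decorator


@register("minkowski")
def minkowski_distances(x1: Sequence, x2: Sequence, p: float = 2.0) -> List[List[float]]:
    return [[sum(abs(a - b) ** p for a, b in zip(u, v)) ** (1.0 / p) for v in x2]
            for u in x1]


@register("chebyshev")
def chebyshev_distances(x1: Sequence, x2: Sequence) -> List[List[float]]:
    return [[max(abs(a - b) for a, b in zip(u, v)) for v in x2] for u in x1]


def cdist(x1: Sequence, x2: Sequence, metric: str = "minkowski", **kwargs) -> List[List[float]]:
    """Look up the metric by name and forward metric-specific keywords."""
    try:
        fn = _METRICS[metric]
    except KeyError:
        raise ValueError(f"unknown metric: {metric!r}") from None
    return fn(x1, x2, **kwargs)


a = [[0.0, 0.0]]
b = [[3.0, 4.0]]
d_default = cdist(a, b)                     # Minkowski with p=2
d_l1 = cdist(a, b, metric="minkowski", p=1.0)
d_inf = cdist(a, b, metric="chebyshev")
```

A keyword that the selected metric does not accept surfaces as an ordinary TypeError from the underlying function, which keeps the wrapper itself small.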

Note that the pairwise metrics in the second table are currently not allowed keys for cdist because they are not distances. We plan a similar wrapper for pdist (which is equivalent to calling cdist(x1, x1) but avoids storing duplicated positions). However, that requires a total overhaul of the existing C++/CUDA kernels and won't be available soon.
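
The storage saving of pdist comes from keeping only the pairs with i < j: for n points that is n(n-1)/2 values instead of the n*n entries of cdist(x1, x1). The plain-Python sketch below illustrates this condensed layout, using the row-major upper-triangle ordering that torch.pdist and scipy.spatial.distance.pdist use. The helper names are hypothetical, and this is not the planned torchpairwise implementation.

```python
import math
from itertools import combinations
from typing import Callable, List, Sequence


def euclidean(u: Sequence[float], v: Sequence[float]) -> float:
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def pdist(x: Sequence[Sequence[float]], metric: Callable = euclidean) -> List[float]:
    """Distances for pairs i < j only, in row-major upper-triangle order."""
    return [metric(x[i], x[j]) for i, j in combinations(range(len(x)), 2)]


def condensed_index(i: int, j: int, n: int) -> int:
    """Position of pair (i, j), i != j, inside the condensed vector."""
    if i == j:
        raise ValueError("the diagonal is not stored")
    if i > j:
        i, j = j, i  # distances are symmetric, so (j, i) maps to (i, j)
    return n * i - i * (i + 1) // 2 + (j - i - 1)


x = [[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]
n = len(x)
condensed = pdist(x)                         # 3 values instead of 9
d_20 = condensed[condensed_index(2, 0, n)]   # distance between x[2] and x[0]
```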

Future Improvements

  • Add more metrics (contact me or create a feature request issue).
  • Add memory-efficient argkmin for retrieving pairwise neighbors' distances and indices without storing the whole pairwise distance matrix.
  • Add an equivalent of torch.pdist with a metric: str = "minkowski" argument.
  • (Unlikely) Support sparse layouts.
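
To give an idea of what a memory-efficient argkmin could look like, the sketch below processes one query row at a time and keeps only the k smallest distances and their indices, so the full pairwise matrix is never stored. It is a plain-Python illustration of the idea under the assumption of Euclidean distance; the function name and return layout are hypothetical, and the feature does not exist in torchpairwise yet.

```python
import heapq
import math
from typing import List, Sequence, Tuple


def argkmin(x1: Sequence[Sequence[float]],
            x2: Sequence[Sequence[float]],
            k: int) -> Tuple[List[List[float]], List[List[int]]]:
    """For each row of x1, return the k smallest distances to x2 and their indices."""
    all_dists: List[List[float]] = []
    all_idxs: List[List[int]] = []
    for u in x1:
        # Lazily generate (distance, index) pairs; only k of them are kept.
        candidates = ((math.dist(u, v), j) for j, v in enumerate(x2))
        best = heapq.nsmallest(k, candidates)
        all_dists.append([d for d, _ in best])
        all_idxs.append([j for _, j in best])
    return all_dists, all_idxs


queries = [[0.0, 0.0]]
points = [[3.0, 4.0], [1.0, 0.0], [0.0, 2.0]]
dists, idxs = argkmin(queries, points, k=2)
```

Peak extra memory is proportional to k per query row rather than to the number of rows in x2.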

Requirements

  • torch>=2.1.0 (torch>=1.9.0 if compiled from source)

Installation

From PyPI:

To install prebuilt wheels of torchpairwise, simply run:

pip install torchpairwise

Note that the Linux and Windows wheels on PyPI are compiled with torch==2.1.0 and CUDA 12.1. We only perform a non-strict version check, and a warning is raised if torch's and torchpairwise's CUDA versions do not match.

From Source:

Make sure your machine has a C++17 compiler and a CUDA compiler installed, then clone the repo and run:

pip install .

Usage

The basic use case is straightforward if you are familiar with sklearn.metrics.pairwise and scipy.spatial.distance. Each scikit-learn or SciPy snippet below is followed by its TorchPairwise equivalent.

scikit-learn:

import numpy as np
import sklearn.metrics.pairwise as sklearn_pairwise

x1 = np.random.rand(10, 5)
x2 = np.random.rand(12, 5)

output = sklearn_pairwise.cosine_similarity(x1, x2)
print(output)

TorchPairwise:

import torch
import torchpairwise

x1 = torch.rand(10, 5, device='cuda')
x2 = torch.rand(12, 5, device='cuda')

output = torchpairwise.cosine_similarity(x1, x2)
print(output)

SciPy:

import numpy as np
import scipy.spatial.distance as distance

x1 = np.random.binomial(
    1, p=0.6, size=(10, 5)).astype(np.bool_)
x2 = np.random.binomial(
    1, p=0.7, size=(12, 5)).astype(np.bool_)

output = distance.cdist(x1, x2, metric='jaccard')
print(output)

TorchPairwise:

import torch
import torchpairwise

x1 = torch.bernoulli(
    torch.full((10, 5), fill_value=0.6, device='cuda')).to(torch.bool)
x2 = torch.bernoulli(
    torch.full((12, 5), fill_value=0.7, device='cuda')).to(torch.bool)

output = torchpairwise.jaccard_distances(x1, x2)
print(output)

Please check the tests folder where we will add more examples.

License

The code is released under the MIT license. See LICENSE.txt for details.
