Pairwise Metrics for PyTorch

TorchPairwise

This package provides highly-efficient pairwise metrics for PyTorch.

Highlights

torchpairwise is a collection of general-purpose pairwise metric functions that behave similarly to torch.cdist, which only implements the $L_p$ distance. It offers many more metrics, ported from other packages such as scipy.spatial.distance and sklearn.metrics.pairwise. If you are looking for task-specific metrics (e.g. for evaluating classification, regression, or clustering), this is the wrong place; please head to the TorchMetrics repo.

The metrics are written in torch's C++ API. Compared with the packages above, they:

  • are all differentiable (except some boolean distances), with backward formulas manually derived, implemented, and verified with torch.autograd.gradcheck.
  • are batched and can exploit GPU parallelization.
  • integrate seamlessly into PyTorch-based projects; all functions are torch.jit.script-able.
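
To illustrate what "batched" and "verified with torch.autograd.gradcheck" mean in practice, here is a minimal sketch that uses only PyTorch. The function sqeuclidean_reference is a hypothetical pure-PyTorch stand-in written for this example; it is not the library's C++ implementation, and it relies on autograd rather than hand-derived backward formulas.

```python
import torch


def sqeuclidean_reference(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: (..., n, d) and (..., m, d) -> (..., n, m)."""
    # Broadcast every row of x1 against every row of x2.
    diff = x1.unsqueeze(-2) - x2.unsqueeze(-3)  # (..., n, m, d)
    return diff.pow(2).sum(dim=-1)


torch.manual_seed(0)
# gradcheck compares analytical and finite-difference gradients,
# so it expects double-precision inputs that require grad.
x1 = torch.rand(2, 4, 3, dtype=torch.float64, requires_grad=True)
x2 = torch.rand(2, 5, 3, dtype=torch.float64, requires_grad=True)

output = sqeuclidean_reference(x1, x2)
print(output.shape)  # torch.Size([2, 4, 5]): one (4, 5) matrix per batch element

grad_ok = torch.autograd.gradcheck(sqeuclidean_reference, (x1, x2))
print(grad_ok)
```

The same kind of check can presumably be run against a torchpairwise op by passing it to gradcheck in place of the reference function.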

List of pairwise distance metrics

| torchpairwise ops | Equivalences in other libraries | Differentiable |
|---|---|---|
| euclidean_distances | sklearn.metrics.pairwise.euclidean_distances | ✔️ |
| haversine_distances | sklearn.metrics.pairwise.haversine_distances | ✔️ |
| manhattan_distances | sklearn.metrics.pairwise.manhattan_distances | ✔️ |
| cosine_distances | sklearn.metrics.pairwise.cosine_distances | ✔️ |
| l1_distances | (Alias of manhattan_distances) | ✔️ |
| l2_distances | (Alias of euclidean_distances) | ✔️ |
| lp_distances | (Alias of minkowski_distances) | ✔️ |
| linf_distances | (Alias of chebyshev_distances) | ✔️ |
| directed_hausdorff_distances | scipy.spatial.distance.directed_hausdorff [^1] | ✔️ |
| minkowski_distances | scipy.spatial.distance.minkowski [^1] | ✔️ |
| wminkowski_distances | scipy.spatial.distance.wminkowski [^1] | ✔️ |
| sqeuclidean_distances | scipy.spatial.distance.sqeuclidean [^1] | ✔️ |
| correlation_distances | scipy.spatial.distance.correlation [^1] | ✔️ |
| hamming_distances | scipy.spatial.distance.hamming [^1] | ❌[^2] |
| jaccard_distances | scipy.spatial.distance.jaccard [^1] | ❌[^2] |
| kulsinski_distances | scipy.spatial.distance.kulsinski [^1] | ❌[^2] |
| kulczynski1_distances | scipy.spatial.distance.kulczynski1 [^1] | ❌[^2] |
| seuclidean_distances | scipy.spatial.distance.seuclidean [^1] | ✔️ |
| cityblock_distances | scipy.spatial.distance.cityblock [^1] (Alias of manhattan_distances) | ✔️ |
| mahalanobis_distances | scipy.spatial.distance.mahalanobis [^1] | ✔️ |
| chebyshev_distances | scipy.spatial.distance.chebyshev [^1] | ✔️ |
| braycurtis_distances | scipy.spatial.distance.braycurtis [^1] | ✔️ |
| canberra_distances | scipy.spatial.distance.canberra [^1] | ✔️ |
| jensenshannon_distances | scipy.spatial.distance.jensenshannon [^1] | ✔️ |
| yule_distances | scipy.spatial.distance.yule [^1] | ❌[^2] |
| dice_distances | scipy.spatial.distance.dice [^1] | ❌[^2] |
| rogerstanimoto_distances | scipy.spatial.distance.rogerstanimoto [^1] | ❌[^2] |
| russellrao_distances | scipy.spatial.distance.russellrao [^1] | ❌[^2] |
| sokalmichener_distances | scipy.spatial.distance.sokalmichener [^1] | ❌[^2] |
| sokalsneath_distances | scipy.spatial.distance.sokalsneath [^1] | ❌[^2] |
| snr_distances | pytorch_metric_learning.distances.SNRDistance [^1] | ✔️ |

[^1]: These metrics are not pairwise but a pairwise form can be computed by calling scipy.spatial.distance.cdist(x1, x2, metric="[metric_name_or_callable]").

[^2]: These are boolean distances. hamming_distances can also be applied to floating-point inputs, but it relies on an element-wise comparison and is therefore not differentiable.
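
The following sketch shows why a comparison-based distance cannot propagate gradients. hamming_reference is a hypothetical pure-PyTorch helper written for this example (not the library's kernel); it computes the fraction of mismatching positions, as scipy.spatial.distance.hamming does.

```python
import torch


def hamming_reference(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: fraction of positions where the rows differ."""
    # The comparison yields a boolean tensor, which cuts the autograd graph.
    mismatches = x1.unsqueeze(-2) != x2.unsqueeze(-3)  # (n, m, d), dtype=bool
    return mismatches.to(torch.float32).mean(dim=-1)


x1 = torch.tensor([[1.0, 0.0, 1.0, 1.0]], requires_grad=True)
x2 = torch.tensor([[1.0, 1.0, 0.0, 1.0],
                   [1.0, 0.0, 1.0, 1.0]])

dist = hamming_reference(x1, x2)
print(dist)                # tensor([[0.5000, 0.0000]])
print(dist.requires_grad)  # False: no gradient flows back to x1
```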

Other pairwise metrics or kernel functions

These functions are typically used to compute kernel (Gram) matrices for machine learning algorithms.

| torchpairwise ops | Equivalences in other libraries | Differentiable |
|---|---|---|
| linear_kernel | sklearn.metrics.pairwise.linear_kernel | ✔️ |
| polynomial_kernel | sklearn.metrics.pairwise.polynomial_kernel | ✔️ |
| sigmoid_kernel | sklearn.metrics.pairwise.sigmoid_kernel | ✔️ |
| rbf_kernel | sklearn.metrics.pairwise.rbf_kernel | ✔️ |
| laplacian_kernel | sklearn.metrics.pairwise.laplacian_kernel | ✔️ |
| cosine_similarity | sklearn.metrics.pairwise.cosine_similarity | ✔️ |
| additive_chi2_kernel | sklearn.metrics.pairwise.additive_chi2_kernel | ✔️ |
| chi2_kernel | sklearn.metrics.pairwise.chi2_kernel | ✔️ |
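
As a reminder of what such a kernel computes, here is a minimal pure-PyTorch sketch of the RBF kernel, K(x, y) = exp(-gamma * ||x - y||^2), following the scikit-learn convention that gamma defaults to 1 / n_features. rbf_kernel_reference is a hypothetical name used only for illustration; whether torchpairwise.rbf_kernel uses the same default is an assumption that should be checked against the library.

```python
import torch


def rbf_kernel_reference(x1: torch.Tensor, x2: torch.Tensor, gamma=None) -> torch.Tensor:
    """Hypothetical reference for exp(-gamma * ||x - y||^2)."""
    if gamma is None:
        gamma = 1.0 / x1.shape[-1]  # scikit-learn's default: 1 / n_features
    sq_dists = (x1.unsqueeze(-2) - x2.unsqueeze(-3)).pow(2).sum(dim=-1)
    return torch.exp(-gamma * sq_dists)


torch.manual_seed(0)
x = torch.rand(6, 4)
gram = rbf_kernel_reference(x, x)

print(gram.shape)       # torch.Size([6, 6])
print(gram.diagonal())  # all ones: each point has zero distance to itself
```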

Custom cdist and pdist

Furthermore, we provide a convenient wrapper function analogous to torch.cdist, except that it takes a string metric: str = "minkowski" as the third argument to select the metric, and extra metric-specific arguments are passed as keywords.

import torch, torchpairwise

# directed_hausdorff_distances is a pairwise 2d metric
x1 = torch.rand(10, 6, 3)
x2 = torch.rand(8, 5, 3)

generator = torch.Generator().manual_seed(1)
output = torchpairwise.cdist(x1, x2,
                             metric="directed_hausdorff",
                             shuffle=True,  # kwargs exclusive to directed_hausdorff
                             generator=generator)

Note that the pairwise metrics in the second table are currently not accepted by cdist because they are not distances. We plan a similar wrapper for pdist (equivalent to calling cdist(x1, x1) while avoiding the storage of duplicated entries). However, that requires a total overhaul of the existing C++/CUDA kernels and won't be available soon.
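
The relationship between pdist and cdist can be seen with PyTorch's built-in functions, which only cover the $L_p$ case. This sketch does not use torchpairwise; it only shows that the condensed pdist output is the strict upper triangle of cdist(x, x).

```python
import torch

torch.manual_seed(0)
x = torch.rand(5, 3)
n = x.shape[0]

# Full (n, n) matrix: symmetric with a zero diagonal, so about half is redundant.
full = torch.cdist(x, x, compute_mode="donot_use_mm_for_euclid_dist")

# Condensed form: only the n * (n - 1) / 2 entries above the diagonal.
condensed = torch.pdist(x)

rows, cols = torch.triu_indices(n, n, offset=1)
print(condensed.shape)                                          # torch.Size([10])
print(torch.allclose(full[rows, cols], condensed, atol=1e-6))
```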

Future Improvements

  • Add more metrics (contact me or create a feature request issue).
  • Add memory-efficient argkmin for retrieving pairwise neighbors' distances and indices without storing the whole pairwise distance matrix.
  • Add an equivalent of torch.pdist with a metric: str = "minkowski" argument.
  • (Unlikely) Support sparse layouts.
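
Until an argkmin op exists, the idea can be approximated in plain PyTorch by processing the queries in chunks. The sketch below is an assumption-laden illustration, not the planned implementation: argkmin_chunked and its chunk_size parameter are hypothetical, and it uses torch.cdist (Euclidean distance only). Peak memory is bounded by chunk_size × m rather than n × m.

```python
import torch


def argkmin_chunked(x1: torch.Tensor, x2: torch.Tensor, k: int, chunk_size: int = 1024):
    """Hypothetical helper: k smallest distances and their indices per row of x1."""
    dists, indices = [], []
    for start in range(0, x1.shape[0], chunk_size):
        # Only a (chunk_size, m) block of the distance matrix exists at a time.
        block = torch.cdist(x1[start:start + chunk_size], x2)
        d, i = block.topk(k, dim=-1, largest=False)
        dists.append(d)
        indices.append(i)
    return torch.cat(dists), torch.cat(indices)


torch.manual_seed(0)
x1 = torch.rand(10, 3)
x2 = torch.rand(7, 3)

knn_dists, knn_idx = argkmin_chunked(x1, x2, k=2, chunk_size=4)
print(knn_dists.shape, knn_idx.shape)  # torch.Size([10, 2]) torch.Size([10, 2])
```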

Requirements

  • torch>=2.1.0 (torch>=1.9.0 if compiled from source)

Installation

From PyPI:

To install the prebuilt torchpairwise wheels, simply run:

pip install torchpairwise

Note that the Linux and Windows wheels on PyPI are compiled with torch==2.1.0 and CUDA 12.1. Only a non-strict version check is performed, and a warning is raised if the CUDA versions of torch and torchpairwise do not match.
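
Before installing a prebuilt wheel, it can help to check which torch and CUDA versions the local environment uses. A minimal sketch using standard PyTorch attributes:

```python
import torch

torch_version = torch.__version__
cuda_version = torch.version.cuda  # None for builds without CUDA support

print(f"torch {torch_version}, CUDA {cuda_version}")
print("CUDA available at runtime:", torch.cuda.is_available())
```

If the reported versions differ from the ones the wheels were built against, compiling from source may be the safer route.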

From Source:

Make sure your machine has a C++17 compiler and a CUDA compiler installed, then clone the repo and run:

pip install .

Usage

Basic usage is straightforward if you are familiar with sklearn.metrics.pairwise and scipy.spatial.distance:

scikit-learn:

import numpy as np
import sklearn.metrics.pairwise as sklearn_pairwise

x1 = np.random.rand(10, 5)
x2 = np.random.rand(12, 5)

output = sklearn_pairwise.cosine_similarity(x1, x2)
print(output)

TorchPairwise:

import torch
import torchpairwise

x1 = torch.rand(10, 5, device='cuda')
x2 = torch.rand(12, 5, device='cuda')

output = torchpairwise.cosine_similarity(x1, x2)
print(output)

SciPy:

import numpy as np
import scipy.spatial.distance as distance

x1 = np.random.binomial(
    1, p=0.6, size=(10, 5)).astype(np.bool_)
x2 = np.random.binomial(
    1, p=0.7, size=(12, 5)).astype(np.bool_)

output = distance.cdist(x1, x2, metric='jaccard')
print(output)

TorchPairwise:

import torch
import torchpairwise

x1 = torch.bernoulli(
    torch.full((10, 5), fill_value=0.6, device='cuda')).to(torch.bool)
x2 = torch.bernoulli(
    torch.full((12, 5), fill_value=0.7, device='cuda')).to(torch.bool)

output = torchpairwise.jaccard_distances(x1, x2)
print(output)

Please check the tests folder, where more examples will be added.

License

The code is released under the MIT license. See LICENSE.txt for details.
