Skip to main content

Differentiable sorting and ranking in PyTorch

Project description

Torchsort

Tests

Fast, differentiable sorting and ranking in PyTorch.

Pure PyTorch implementation of Fast Differentiable Sorting and Ranking (Blondel et al.). Much of the code is copied from the original Numpy implementation at google-research/fast-soft-sort, with the isotonic regression solver rewritten as a PyTorch C++ Extension.

NOTE: I am actively working on this. The API should remain about the same; but expect more optimizations and benchmarks soon. The C++ isotonic regression solver is currently only implemented on CPU, so CUDA tensors will be copied over to CPU to perform the operations. I am currently working on the CUDA kernel implementation, which should be done soon.

Install

pip install torchsort

Usage

torchsort exposes two functions: soft_rank and soft_sort, each with parameters regularization ("l2" or "kl") and regularization_strength (a scalar value). Each will rank/sort the last dimension of a 2-d tensor, with an accuracy dependant upon the regularization strength:

import torch
import torchsort

x = torch.tensor([[8, 0, 5, 3, 2, 1, 6, 7, 9]])

torchsort.soft_sort(x, regularization_strength=1.0)
# tensor([[0.5556, 1.5556, 2.5556, 3.5556, 4.5556, 5.5556, 6.5556, 7.5556, 8.5556]])
torchsort.soft_sort(x, regularization_strength=0.1)
# tensor([[-0., 1., 2., 3., 5., 6., 7., 8., 9.]])

torchsort.soft_rank(x)
# tensor([[8., 1., 5., 4., 3., 2., 6., 7., 9.]])

Both operations are fully differentiable, on CPU or GPU:

x = torch.tensor([[8., 0., 5., 3., 2., 1., 6., 7., 9.]], requires_grad=True).cuda()
y = torchsort.soft_sort(x)

torch.autograd.grad(y[0, 0], x)
# (tensor([[0.1111, 0.1111, 0.1111, 0.1111, 0.1111, 0.1111, 0.1111, 0.1111, 0.1111]],
#         device='cuda:0'),)

Benchmark

Benchmark

torchsort and fast_soft_sort each operate with a time complexity of O(n log n), each with some additional overhead when compared to the built-in torch.sort. The Numba JIT'd forward pass of fast_soft_sort performs about on-par with the torchsort CPU kernel, however its backward pass still relies on some Python code, which greatly penalizes its performance. CUDA kernel is coming soon!

Reference

Please site the original paper:

@inproceedings{blondel2020fast,
  title={Fast differentiable sorting and ranking},
  author={Blondel, Mathieu and Teboul, Olivier and Berthet, Quentin and Djolonga, Josip},
  booktitle={International Conference on Machine Learning},
  pages={950--959},
  year={2020},
  organization={PMLR}
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchsort-0.0.4.tar.gz (8.4 kB view details)

Uploaded Source

File details

Details for the file torchsort-0.0.4.tar.gz.

File metadata

  • Download URL: torchsort-0.0.4.tar.gz
  • Upload date:
  • Size: 8.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.7.3 pkginfo/1.7.0 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.54.1 CPython/3.8.5

File hashes

Hashes for torchsort-0.0.4.tar.gz
Algorithm Hash digest
SHA256 6e182639e29867c6cdd967a35f14dc8a4c70eeac5644c37c4d7bab8bb9ed9a07
MD5 1c296b5494f4115fae13501e35000186
BLAKE2b-256 676b0a556496622c092503fd779cda0b95ded2327c0f8f07cb815cb7edfd20f5

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page