
Batched linear assignment with PyTorch and CUDA.

Project description

Batch linear assignment for PyTorch



Batch computation of the linear assignment problem on GPU.
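For reference, here is one way to state the problem being solved. The notation is ours and is not taken from the package. Each cost matrix $C^{(b)} \in \mathbb{R}^{n \times m}$ in the batch is solved independently: every row is matched to at most one column and vice versa, $\min(n, m)$ pairs are formed, and the total cost is minimized.

```latex
\min_{x \in \{0,1\}^{n \times m}} \sum_{i=1}^{n} \sum_{j=1}^{m} C^{(b)}_{ij}\, x_{ij}
\quad \text{s.t.} \quad
\sum_{j=1}^{m} x_{ij} \le 1 \;\; \forall i, \qquad
\sum_{i=1}^{n} x_{ij} \le 1 \;\; \forall j, \qquad
\sum_{i=1}^{n} \sum_{j=1}^{m} x_{ij} = \min(n, m)
```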

Install

Build and install via PyPI (source distribution):

pip install torch-linear-assignment

Build and install from a clone of the Git repository (run from the repository root):

pip install .

By default, pip builds the package in an isolated environment, which may install a different PyTorch version than the one you use. To build against the PyTorch already installed in the current environment and reduce disk usage, disable build isolation:

pip install --no-build-isolation torch-linear-assignment

When building with CUDA, make sure NVCC uses the same CUDA version as PyTorch. You can select the CUDA toolkit by putting its bin directory first in PATH:

export PATH=/usr/local/cuda-<version>/bin:"$PATH"
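Before building, you can compare the two versions programmatically. The sketch below assumes `nvcc` is on PATH. `torch.version.cuda` is a real PyTorch attribute (it is `None` for CPU-only builds), while the helper names are hypothetical and not part of torch-linear-assignment. It treats "same version" as matching major.minor release numbers, which is one reasonable reading of the requirement above.

```python
import re
import shutil
import subprocess
from typing import Optional


def parse_nvcc_release(nvcc_output: str) -> Optional[str]:
    """Extract the 'major.minor' release from `nvcc --version` output (hypothetical helper)."""
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


def same_major_minor(a: str, b: str) -> bool:
    """Compare two CUDA version strings on their major.minor components only."""
    return a.split(".")[:2] == b.split(".")[:2]


def check_cuda_toolkit() -> None:
    """Print whether the nvcc found on PATH matches PyTorch's CUDA version (hypothetical helper)."""
    import torch  # imported lazily so the parsing helpers work without torch

    torch_cuda = torch.version.cuda  # None for CPU-only PyTorch builds
    nvcc = shutil.which("nvcc")
    if torch_cuda is None or nvcc is None:
        print(f"torch CUDA: {torch_cuda}, nvcc: {nvcc}")
        return
    out = subprocess.run([nvcc, "--version"], capture_output=True, text=True).stdout
    nvcc_cuda = parse_nvcc_release(out)
    ok = nvcc_cuda is not None and same_major_minor(nvcc_cuda, torch_cuda)
    print(f"torch CUDA {torch_cuda}, nvcc CUDA {nvcc_cuda}: {'OK' if ok else 'MISMATCH'}")
```

Call `check_cuda_toolkit()` in the environment you plan to build in. On a mismatch, adjust PATH as shown above.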

If you need a custom C/C++ compiler, set the CXX and CC environment variables:

CXX=<c++-compiler> CC=<gcc-compiler> pip install .

If the build fails with a "torch not found" error, upgrade the packaging tools and retry:

pip install --upgrade pip wheel setuptools
python -m pip install .

Example

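import torch
from torch_linear_assignment import batch_linear_assignment

# Cost tensor of shape (batch, rows, columns) = (1, 4, 3).
cost = torch.tensor([
    8, 4, 7,
    5, 2, 3,
    9, 6, 7,
    9, 4, 8
]).reshape(1, 4, 3).cuda()  # requires a CUDA-capable GPU

# For each row: the index of the assigned column, or -1 if the row is unassigned.
assignment = batch_linear_assignment(cost)
print(assignment)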

The output is:

tensor([[ 0,  2, -1,  1]], device='cuda:0')
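Each row receives a distinct column. Row 2 gets -1 because there are more rows than columns. The selected pairs minimize the total cost: 8 + 3 + 4 = 15. For small cases, a brute-force reference like the following can cross-check results on the CPU. It is a hypothetical helper, not part of torch_linear_assignment, and runs in exponential time. When several assignments tie for the minimum, it may return a different one than the library.

```python
from itertools import permutations
from typing import List, Sequence


def brute_force_assignment(cost: Sequence[Sequence[float]]) -> List[int]:
    """Return, for each row, its assigned column, or -1 if the row is unassigned.

    Exhaustively searches all matchings of size min(n, m); only practical
    for tiny matrices.
    """
    n, m = len(cost), len(cost[0])
    best_total, best = None, None
    if n <= m:
        # Every row gets a distinct column.
        for cols in permutations(range(m), n):
            total = sum(cost[i][j] for i, j in enumerate(cols))
            if best_total is None or total < best_total:
                best_total, best = total, list(cols)
    else:
        # Every column gets a distinct row; leftover rows stay at -1.
        for rows in permutations(range(n), m):
            total = sum(cost[i][j] for j, i in enumerate(rows))
            if best_total is None or total < best_total:
                assignment = [-1] * n
                for j, i in enumerate(rows):
                    assignment[i] = j
                best_total, best = total, assignment
    return best


cost = [
    [8, 4, 7],
    [5, 2, 3],
    [9, 6, 7],
    [9, 4, 8],
]
print(brute_force_assignment(cost))  # [0, 2, -1, 1]
```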

To get row and column indices in the format returned by SciPy's linear_sum_assignment:

from torch_linear_assignment import assignment_to_indices

row_ind, col_ind = assignment_to_indices(assignment)
print(row_ind)
print(col_ind)

The output is:

tensor([[0, 1, 3]], device='cuda:0')
tensor([[0, 2, 1]], device='cuda:0')
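For a single batch element, this conversion amounts to dropping the unassigned (-1) rows and pairing each remaining row index with its column. The plain-Python helper below is a hypothetical illustration of that idea, not the library's implementation of assignment_to_indices. It does not show how the library handles batches whose elements have different numbers of assigned rows.

```python
from typing import List, Sequence, Tuple


def assignment_to_row_col(assignment: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Split a per-row assignment (-1 = unassigned) into SciPy-style index lists."""
    row_ind = [i for i, j in enumerate(assignment) if j >= 0]  # ascending row order
    col_ind = [j for j in assignment if j >= 0]                # matching columns
    return row_ind, col_ind


row_ind, col_ind = assignment_to_row_col([0, 2, -1, 1])
print(row_ind)  # [0, 1, 3]
print(col_ind)  # [0, 2, 1]
```

With these lists, the total cost of the matching is `sum(cost[r][c] for r, c in zip(row_ind, col_ind))`.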

Citation

The code was originally developed for the HoTPP Benchmark. If you use this code in your research project, please cite one of the following papers:

@article{karpukhin2024hotppbenchmark,
  title={HoTPP Benchmark: Are We Good at the Long Horizon Events Forecasting?},
  author={Karpukhin, Ivan and Shipilov, Foma and Savchenko, Andrey},
  journal={arXiv preprint arXiv:2406.14341},
  year={2024},
  url ={https://arxiv.org/abs/2406.14341}
}

@article{karpukhin2024detpp,
  title={DeTPP: Leveraging Object Detection for Robust Long-Horizon Event Prediction},
  author={Karpukhin, Ivan and Savchenko, Andrey},
  journal={arXiv preprint arXiv:2408.13131},
  year={2024},
  url ={https://arxiv.org/abs/2408.13131}
}



Download files


Source Distribution

torch_linear_assignment-0.0.5.tar.gz (12.4 kB)

File details

Details for the file torch_linear_assignment-0.0.5.tar.gz.

File metadata

  • Download URL: torch_linear_assignment-0.0.5.tar.gz
  • Upload date:
  • Size: 12.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.23

File hashes

Hashes for torch_linear_assignment-0.0.5.tar.gz
Algorithm Hash digest
SHA256 8f0e92021548bd702e5dc9ce83e1b5a460f69ec1224961df8928050ea43a0f89
MD5 8b5450cfc24874ecc8cd0fa5a316f27b
BLAKE2b-256 962360a98daf5baa31733960e35318452e431b9ba21e15b659813496c4f257f2

