
Batched linear assignment with PyTorch and CUDA.


Batch linear assignment for PyTorch



Batch computation of the linear assignment problem on GPU.
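For reference, here is the problem solved for each batch element, as we read it from the example below and from the SciPy-compatible index format. This formulation is our summary, not taken from the project's documentation. Given a cost matrix C^{(b)} with n rows and m columns, pick k = min(n, m) row-column pairs with no row or column used twice, minimizing the total cost. Rows left over when n > m (or columns when m > n) stay unassigned.

```latex
\min_{\pi \in \Pi_k} \sum_{(i,j) \in \pi} C^{(b)}_{ij},
\qquad k = \min(n, m),
\qquad \Pi_k = \bigl\{\pi \subseteq \{1..n\} \times \{1..m\} : |\pi| = k,\ \text{no row or column repeated}\bigr\}
```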

Install

Build and install via PyPI (source distribution):

pip install torch-linear-assignment

Build and install from a local clone of the Git repository (run in the repository root):

pip install .

When building with CUDA, make sure NVCC reports the same CUDA version that PyTorch was built with. You can select the CUDA version by putting the corresponding toolkit first on PATH:

export PATH=/usr/local/cuda-<version>/bin:"$PATH"
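A small helper can check the version match before building. This is a sketch, not part of the package. It assumes `nvcc --version` prints a line like "Cuda compilation tools, release 12.1, V12.1.105" and that `torch.version.cuda` holds PyTorch's CUDA version (None for CPU-only builds). The helper names `nvcc_release` and `cuda_versions_match` are hypothetical, and comparing major.minor is our reading of "same CUDA version".

```python
import re
import shutil
import subprocess
from typing import Optional


def nvcc_release(version_text: str) -> Optional[str]:
    """Extract 'X.Y' from `nvcc --version` output (hypothetical helper)."""
    match = re.search(r"release (\d+\.\d+)", version_text)
    return match.group(1) if match else None


def cuda_versions_match(nvcc_version: Optional[str], torch_cuda: Optional[str]) -> bool:
    """Compare major.minor of NVCC and PyTorch CUDA versions (hypothetical helper)."""
    if nvcc_version is None or torch_cuda is None:
        return False
    return nvcc_version.split(".")[:2] == torch_cuda.split(".")[:2]


if __name__ == "__main__":
    # Locate nvcc on PATH (the same PATH that `pip install` will see).
    nvcc_path = shutil.which("nvcc")
    nvcc_version = None
    if nvcc_path is not None:
        result = subprocess.run([nvcc_path, "--version"], capture_output=True, text=True)
        nvcc_version = nvcc_release(result.stdout)
    # torch.version.cuda is None for CPU-only PyTorch builds.
    try:
        import torch
        torch_cuda = torch.version.cuda
    except ImportError:
        torch_cuda = None
    print(f"nvcc: {nvcc_version}, torch: {torch_cuda}, "
          f"match: {cuda_versions_match(nvcc_version, torch_cuda)}")
```

Run it in the same shell after adjusting PATH. A mismatch, or a missing nvcc, shows up as `match: False`.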

To build with a custom C/C++ compiler, set the CXX and CC environment variables:

CXX=<c++-compiler> CC=<gcc-compiler> pip install .

If the build fails because torch cannot be found, upgrade the packaging tools and retry:

pip install --upgrade pip wheel setuptools
python -m pip install .

Example

import torch
from torch_linear_assignment import batch_linear_assignment

cost = torch.tensor([
    8, 4, 7,
    5, 2, 3,
    9, 6, 7,
    9, 4, 8
]).reshape(1, 4, 3).cuda()

assignment = batch_linear_assignment(cost)
print(assignment)

The output is:

tensor([[ 0,  2, -1,  1]], device='cuda:0')
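Each entry of the output is the column assigned to that row, or -1 if the row is unassigned (here row 2, since there are 4 rows but only 3 columns). To check small cases without a GPU, the following exhaustive-search reference uses the same output format. It is a hypothetical helper for illustration only, not the algorithm the package uses. It runs in factorial time, so it only suits tiny matrices. When several assignments tie for the minimum cost, it may pick a different one than the library.

```python
from itertools import permutations
from typing import List, Sequence


def brute_force_assignment(cost: Sequence[Sequence[float]]) -> List[int]:
    """For each row, return its assigned column or -1 (exhaustive search, hypothetical helper)."""
    n_rows, n_cols = len(cost), len(cost[0])
    if n_rows > n_cols:
        # More rows than columns: solve the transposed problem (every column gets a row),
        # then invert the column -> row mapping; unmatched rows keep -1.
        transposed = [list(col) for col in zip(*cost)]
        col_to_row = brute_force_assignment(transposed)
        assignment = [-1] * n_rows
        for col, row in enumerate(col_to_row):
            assignment[row] = col
        return assignment
    # n_rows <= n_cols: every row gets a distinct column; try every choice of columns.
    best = min(
        permutations(range(n_cols), n_rows),
        key=lambda cols: sum(cost[r][c] for r, c in enumerate(cols)),
    )
    return list(best)


cost = [
    [8, 4, 7],
    [5, 2, 3],
    [9, 6, 7],
    [9, 4, 8],
]
print(brute_force_assignment(cost))  # [0, 2, -1, 1], matching the GPU output above
```

For a batch, apply it per matrix, e.g. `[brute_force_assignment(c) for c in batch]`.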

To get indices in the format returned by SciPy's scipy.optimize.linear_sum_assignment:

from torch_linear_assignment import assignment_to_indices

row_ind, col_ind = assignment_to_indices(assignment)
print(row_ind)
print(col_ind)

The output is:

tensor([[0, 1, 3]], device='cuda:0')
tensor([[0, 2, 1]], device='cuda:0')
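The conversion keeps only assigned rows, in ascending row order, and pairs each with its column. This follows SciPy's convention, where the row indices come out sorted. Below is a pure-Python sketch of that conversion on nested lists. The name `assignment_to_indices_sketch` is hypothetical. The real `assignment_to_indices` works on tensors, and its implementation may differ.

```python
from typing import List, Tuple


def assignment_to_indices_sketch(
    assignment: List[List[int]],
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical list-based analogue of assignment_to_indices for a batch."""
    row_ind, col_ind = [], []
    for per_row in assignment:
        # Drop unassigned rows (-1); enumerate keeps the remaining rows in ascending order.
        pairs = [(row, col) for row, col in enumerate(per_row) if col >= 0]
        row_ind.append([row for row, _ in pairs])
        col_ind.append([col for _, col in pairs])
    return row_ind, col_ind


rows, cols = assignment_to_indices_sketch([[0, 2, -1, 1]])
print(rows)  # [[0, 1, 3]]
print(cols)  # [[0, 2, 1]]
```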

Citation

The code was originally developed for the HoTPP Benchmark. If you use this code in your research project, please cite one of the following papers:

@article{karpukhin2024hotppbenchmark,
  title={HoTPP Benchmark: Are We Good at the Long Horizon Events Forecasting?},
  author={Karpukhin, Ivan and Shipilov, Foma and Savchenko, Andrey},
  journal={arXiv preprint arXiv:2406.14341},
  year={2024},
  url={https://arxiv.org/abs/2406.14341}
}

@article{karpukhin2024detpp,
  title={DeTPP: Leveraging Object Detection for Robust Long-Horizon Event Prediction},
  author={Karpukhin, Ivan and Savchenko, Andrey},
  journal={arXiv preprint arXiv:2408.13131},
  year={2024},
  url={https://arxiv.org/abs/2408.13131}
}


Source Distribution

torch_linear_assignment-0.0.1.post3.tar.gz (12.0 kB)

Hashes for torch_linear_assignment-0.0.1.post3.tar.gz:

Algorithm     Hash digest
SHA256        634e013e0a33422f4414734b25cedbf01dc1d6908d1b624a7c5bae7cd802a83b
MD5           4f251753aa2fe2db33530a892385ad69
BLAKE2b-256   5428f6b067c9f8086b06bac1fc93499773142881602a3c4a10f670f2240849f3
