Batched linear assignment with PyTorch and CUDA.
Batch linear assignment for PyTorch
Batch computation of the linear assignment problem on GPU.
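For reference, each matrix in the batch defines a standard rectangular linear assignment problem. The formulation below is a sketch that assumes an $n \times m$ cost matrix $C$ and a binary selection matrix $X$ (notation introduced here, not by the library): every row and every column is used at most once, exactly $\min(n, m)$ row-column pairs are selected, and their total cost is minimized.

```latex
\min_{X \in \{0,1\}^{n \times m}} \sum_{i=1}^{n} \sum_{j=1}^{m} C_{ij} X_{ij}
\quad \text{s.t.} \quad
\sum_{j=1}^{m} X_{ij} \le 1 \;\; \forall i, \qquad
\sum_{i=1}^{n} X_{ij} \le 1 \;\; \forall j, \qquad
\sum_{i=1}^{n} \sum_{j=1}^{m} X_{ij} = \min(n, m).
```

When $n > m$, some rows necessarily stay without a column.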
Install
Build and install via PyPI (source distribution):
pip install torch-linear-assignment
Build and install from a local clone of the Git repository:
pip install .
When building with CUDA, make sure that NVCC has the same CUDA version as the one PyTorch was built with. You can select the CUDA toolkit version by putting its bin directory first in PATH:
export PATH=/usr/local/cuda-<version>/bin:"$PATH"
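To check the versions before building, a small helper can compare the CUDA version PyTorch was built with (`torch.version.cuda`, which is `None` for CPU-only builds) against the release reported by `nvcc --version`. This is a minimal sketch, not part of torch-linear-assignment; the helper names `nvcc_cuda_version` and `check_cuda_versions` are hypothetical, and the parser assumes the usual `release X.Y` wording of the NVCC banner.

```python
import re
import shutil
import subprocess
from typing import Optional


def nvcc_cuda_version(output: str) -> Optional[str]:
    """Extract 'MAJOR.MINOR' from `nvcc --version` output (assumes 'release X.Y')."""
    match = re.search(r"release (\d+\.\d+)", output)
    return match.group(1) if match else None


def check_cuda_versions() -> None:
    """Print the CUDA versions seen by PyTorch and by the nvcc found on PATH."""
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed.")
        return
    nvcc = shutil.which("nvcc")  # resolves nvcc using the current PATH
    if nvcc is None:
        print("nvcc was not found on PATH.")
        return
    result = subprocess.run([nvcc, "--version"], capture_output=True, text=True)
    torch_cuda = torch.version.cuda  # None for CPU-only PyTorch builds
    nvcc_cuda = nvcc_cuda_version(result.stdout)
    print(f"PyTorch built with CUDA {torch_cuda}, nvcc reports CUDA {nvcc_cuda}")
    if torch_cuda != nvcc_cuda:
        print("Versions differ: adjust PATH before building.")


if __name__ == "__main__":
    check_cuda_versions()
```

Run it after changing PATH; if the two versions differ, point PATH at a different `/usr/local/cuda-<version>/bin`.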
If you need a custom C/C++ compiler, use the following command:
CXX=<c++-compiler> CC=<gcc-compiler> pip install .
If you get a torch-not-found error during the build, try the following commands:
pip install --upgrade pip wheel setuptools
python -m pip install .
Example
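import torch
from torch_linear_assignment import batch_linear_assignment
# A batch with one 4x3 cost matrix, shape (batch, rows, columns), moved to the GPU.
cost = torch.tensor([
    8, 4, 7,
    5, 2, 3,
    9, 6, 7,
    9, 4, 8
]).reshape(1, 4, 3).cuda()
# For each row: the index of its assigned column, or -1 if the row is unassigned.
assignment = batch_linear_assignment(cost)
print(assignment)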
The output is:
tensor([[ 0, 2, -1, 1]], device='cuda:0')
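Each entry of the output is the column assigned to the corresponding row; -1 marks a row that received no column, since the matrix has 4 rows but only 3 columns. To sanity-check such results on small inputs without a GPU, an exhaustive search can serve as a reference. The function `brute_force_assignment` below is a hypothetical helper written for illustration, not part of torch_linear_assignment, and it is exponential in the matrix size, so it only suits tiny matrices.

```python
from itertools import permutations
from typing import Dict, List, Optional, Sequence


def brute_force_assignment(cost: Sequence[Sequence[float]]) -> List[int]:
    """Return, for each row, its assigned column or -1, by trying every assignment."""
    n_rows, n_cols = len(cost), len(cost[0])
    best_total = float("inf")
    best: Optional[Dict[int, int]] = None  # maps row -> column
    if n_rows >= n_cols:
        # rows[c] is the row that receives column c; every column gets a distinct row.
        for rows in permutations(range(n_rows), n_cols):
            total = sum(cost[r][c] for c, r in enumerate(rows))
            if total < best_total:
                best_total, best = total, {r: c for c, r in enumerate(rows)}
    else:
        # cols[r] is the column given to row r; every row gets a distinct column.
        for cols in permutations(range(n_cols), n_rows):
            total = sum(cost[r][c] for r, c in enumerate(cols))
            if total < best_total:
                best_total, best = total, dict(enumerate(cols))
    assert best is not None
    return [best.get(r, -1) for r in range(n_rows)]


cost = [[8, 4, 7], [5, 2, 3], [9, 6, 7], [9, 4, 8]]
print(brute_force_assignment(cost))  # [0, 2, -1, 1]
```

For this matrix the optimal total cost is 8 + 3 + 4 = 15, and the exhaustive search reproduces the GPU result above.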
To get indices in SciPy's format (the (row_ind, col_ind) pair returned by scipy.optimize.linear_sum_assignment):
from torch_linear_assignment import assignment_to_indices
row_ind, col_ind = assignment_to_indices(assignment)
print(row_ind)
print(col_ind)
The output is:
tensor([[0, 1, 3]], device='cuda:0')
tensor([[0, 2, 1]], device='cuda:0')
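Judging from the outputs above, the conversion keeps only the rows with a non-negative assignment and pairs each with its column. The sketch below reproduces that filtering on plain Python lists; `to_indices` is a hypothetical name, and this is an assumption about the behavior of assignment_to_indices, not its actual implementation.

```python
from typing import List, Sequence, Tuple


def to_indices(
    assignment: Sequence[Sequence[int]],
) -> Tuple[List[List[int]], List[List[int]]]:
    """Split per-row assignments into SciPy-style (row_ind, col_ind), per batch element."""
    row_ind: List[List[int]] = []
    col_ind: List[List[int]] = []
    for per_row in assignment:
        # Keep only rows that received a column; -1 marks an unassigned row.
        pairs = [(r, c) for r, c in enumerate(per_row) if c >= 0]
        row_ind.append([r for r, _ in pairs])
        col_ind.append([c for _, c in pairs])
    return row_ind, col_ind


rows, cols = to_indices([[0, 2, -1, 1]])
print(rows)  # [[0, 1, 3]]
print(cols)  # [[0, 2, 1]]
```

Because rows are visited in order, each row_ind list comes out sorted, as in SciPy's output.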
Citation
The code was originally developed for the HoTPP Benchmark. If you use this code in your research project, please cite one of the following papers:
@article{karpukhin2024hotppbenchmark,
title={HoTPP Benchmark: Are We Good at the Long Horizon Events Forecasting?},
author={Karpukhin, Ivan and Shipilov, Foma and Savchenko, Andrey},
journal={arXiv preprint arXiv:2406.14341},
year={2024},
url ={https://arxiv.org/abs/2406.14341}
}
@article{karpukhin2024detpp,
title={DeTPP: Leveraging Object Detection for Robust Long-Horizon Event Prediction},
author={Karpukhin, Ivan and Savchenko, Andrey},
journal={arXiv preprint arXiv:2408.13131},
year={2024},
url ={https://arxiv.org/abs/2408.13131}
}