
Unbalanced Optimal Transport for Object Detection


Unbalanced Optimal Transport: A Unified Framework for Object Detection

Presentation | Paper | Supplementary | Documentation


H. De Plaen, P.-F. De Plaen, J. A. K. Suykens, M. Proesmans, T. Tuytelaars, and L. Van Gool, “Unbalanced Optimal Transport: A Unified Framework for Object Detection,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), Jun. 2023, pp. 3198–3207.

This work was presented at CVPR 2023 in Vancouver, Canada. The paper and additional resources can be found on the project website. The paper is open access and is also available on the CVF website and on IEEE Xplore.

Different matching strategies, all of which are particular cases of Unbalanced Optimal Transport.

Abstract

TL;DR: We introduce a versatile new class of matching strategies that unifies many existing ones and is well suited for GPUs.

During training, supervised object detection tries to correctly match the predicted bounding boxes and associated classification scores to the ground truth. This is essential to determine which predictions are to be pushed towards which solutions, or to be discarded. Popular matching strategies include matching to the closest ground truth box (mostly used in combination with anchors), or matching via the Hungarian algorithm (mostly used in anchor-free methods). Each of these strategies comes with its own properties, underlying losses, and heuristics. We show how Unbalanced Optimal Transport unifies these different approaches and opens a whole continuum of methods in between. This allows for a finer selection of the desired properties. Experimentally, we show that training an object detection model with Unbalanced Optimal Transport reaches the state of the art in terms of both Average Precision and Average Recall, and provides faster initial convergence. The approach is well suited for GPU implementation, which proves to be an advantage for large-scale models.
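To make the idea of a continuum more concrete, here is a minimal NumPy sketch of entropic Optimal Transport with a KL-relaxed target marginal, solved with the standard Sinkhorn-like scaling iterations. It is an illustration under our own assumptions, not the exact formulation or API of uotod: the function `semi_relaxed_sinkhorn` and its parameters `eps` (entropic regularization) and `tau` (strength of the ground-truth marginal constraint) are hypothetical. Each prediction must send all of its mass; for large `tau` the ground-truth marginal is enforced (balanced OT), while for `tau = 0` every prediction simply moves its mass towards its closest ground truth.

```python
import numpy as np


def semi_relaxed_sinkhorn(cost, a, b, eps=0.1, tau=1.0, n_iter=2000):
    """Entropic OT plan with exact row marginals `a` and KL-relaxed column marginals `b`.

    Hypothetical helper for illustration only: `eps` is the entropic regularization,
    `tau` weights the KL penalty on the column (ground-truth) marginal.
    tau -> inf recovers balanced Sinkhorn; tau = 0 drops the column constraint entirely.
    """
    K = np.exp(-cost / eps)             # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    f = tau / (tau + eps)               # exponent of the relaxed (proximal) column update
    for _ in range(n_iter):
        u = a / (K @ v)                 # enforce the prediction marginal exactly
        v = (b / (K.T @ u)) ** f        # only softly enforce the ground-truth marginal
    return u[:, None] * K * v[None, :]  # transport plan diag(u) K diag(v)


# Toy cost between 3 predictions (rows) and 3 ground truths (columns):
# every prediction is closest to ground truth 0.
cost = np.array([[0.1, 0.5, 0.9],
                 [0.2, 0.6, 0.8],
                 [0.15, 0.4, 0.7]])
a = np.full(3, 1 / 3)
b = np.full(3, 1 / 3)

relaxed = semi_relaxed_sinkhorn(cost, a, b, tau=0.0)    # closest-ground-truth behaviour
balanced = semi_relaxed_sinkhorn(cost, a, b, tau=1e6)   # approximately balanced OT
print(relaxed.argmax(axis=1))  # [0 0 0]: all predictions go to ground truth 0
print(balanced.sum(axis=0))    # column sums are close to b: mass is spread over all ground truths
```

Intermediate values of `tau` interpolate between these two extremes. The package itself additionally models a background option (see the `background_cost` parameter in the examples) and works on PyTorch tensors.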

Install

PyPI

Using PyPI, it suffices to run pip install uotod. To update the package to its newest version, run pip install --upgrade uotod.
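To check which version is currently installed (for example after upgrading), the standard library's `importlib.metadata` can be queried. A small sketch; the helper name `installed_version` is ours and not part of uotod:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str = "uotod"):
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


print(installed_version())  # a version string, or None if uotod is not installed
```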

Build From Source

You can also download it directly from the GitHub repository, then build and install it.

git clone --recursive https://github.com/hdeplaen/uotod
cd uotod
python3 -m pip install -r requirements.txt
python3 -m setup build
python3 -m pip install .

Compiled Acceleration

The package is available on all platforms and runs well. However, only the combinations marked with a green ✅ can take advantage of the compiled version of Sinkhorn's algorithm directly from PyPI. For an unsupported combination, you can always build the package from source to get access to the compiled version as well. In any case, the PyTorch implementation of Sinkhorn's algorithm is always available (and used by default); the table below only concerns the additional compiled version.

OS Linux MacOS Windows
Python 3.8 ☑️
Python 3.9 ☑️
Python 3.10 ☑️
Python 3.11 ☑️
Python 3.12 ☑️ ☑️
  • ✅: Python implementation + compiled acceleration, both directly from PyPI
  • ☑️: Python implementation directly from PyPI (+ possible compiled acceleration if building from source)

Examples

OT matching with GIoU loss:

from uotod.match import BalancedSinkhorn
from uotod.loss import GIoULoss

ot = BalancedSinkhorn(
    loc_match_module=GIoULoss(reduction="none"),
    background_cost=0.,
)
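For intuition about the localization cost used above, the sketch below computes a pairwise GIoU-based cost matrix (1 - GIoU) between predicted and ground-truth boxes in NumPy. It assumes boxes in `(x1, y1, x2, y2)` format and that `GIoULoss(reduction="none")` provides such a per-pair cost for matching; both are assumptions, and `pairwise_giou_cost` is a hypothetical helper, not the package's implementation.

```python
import numpy as np


def pairwise_giou_cost(pred, gt):
    """Return the (N, M) matrix 1 - GIoU(pred[i], gt[j]) for boxes given as (x1, y1, x2, y2)."""
    p = pred[:, None, :]  # (N, 1, 4)
    g = gt[None, :, :]    # (1, M, 4)
    area_p = (p[..., 2] - p[..., 0]) * (p[..., 3] - p[..., 1])
    area_g = (g[..., 2] - g[..., 0]) * (g[..., 3] - g[..., 1])

    # Intersection rectangle, clipped at zero when the boxes do not overlap.
    iw = np.clip(np.minimum(p[..., 2], g[..., 2]) - np.maximum(p[..., 0], g[..., 0]), 0, None)
    ih = np.clip(np.minimum(p[..., 3], g[..., 3]) - np.maximum(p[..., 1], g[..., 1]), 0, None)
    inter = iw * ih
    union = area_p + area_g - inter
    iou = inter / union

    # Area of the smallest box enclosing both boxes.
    cw = np.maximum(p[..., 2], g[..., 2]) - np.minimum(p[..., 0], g[..., 0])
    ch = np.maximum(p[..., 3], g[..., 3]) - np.minimum(p[..., 1], g[..., 1])
    enclosing = cw * ch

    giou = iou - (enclosing - union) / enclosing
    return 1.0 - giou


pred = np.array([[0., 0., 2., 2.], [3., 3., 4., 4.]])
gt = np.array([[0., 0., 2., 2.], [1., 1., 3., 3.]])
print(pairwise_giou_cost(pred, gt))
```

The cost is 0 for identical boxes and stays informative for disjoint boxes: GIoU lies in (-1, 1], so 1 - GIoU lies in [0, 2) and still distinguishes near misses from far misses when the IoU is 0.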

Hungarian matching (bipartite) with GIoU loss:

from uotod.match import Hungarian
from uotod.loss import GIoULoss

hungarian = Hungarian(
    loc_match_module=GIoULoss(reduction="none"),
    background_cost=0.,
)
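Hungarian matching computes a minimum-cost one-to-one assignment. For very small problems, the same result can be obtained by exhaustive search, which makes the semantics easy to check. The sketch below assumes at least as many predictions as ground truths, matches every ground truth to a distinct prediction, and sends the remaining predictions to the background at `background_cost` each. `one_to_one_match` is a hypothetical helper; a practical implementation would use the Hungarian algorithm (for example `scipy.optimize.linear_sum_assignment`) rather than brute force.

```python
from itertools import permutations

import numpy as np


def one_to_one_match(cost, background_cost=0.0):
    """Brute-force minimum-cost one-to-one matching (illustration only, exponential time).

    cost[i, j] is the cost of matching prediction i to ground truth j; requires N >= M.
    Returns (assignment, total) where assignment[i] is the matched ground truth or -1 (background).
    """
    n_pred, n_gt = cost.shape
    best_total, best_assignment = np.inf, None
    # Try every ordered choice of M distinct predictions, one per ground truth.
    for preds in permutations(range(n_pred), n_gt):
        total = sum(cost[p, j] for j, p in enumerate(preds))
        total += (n_pred - n_gt) * background_cost  # unmatched predictions go to background
        if total < best_total:
            assignment = -np.ones(n_pred, dtype=int)
            for j, p in enumerate(preds):
                assignment[p] = j
            best_total, best_assignment = total, assignment
    return best_assignment, best_total


cost = np.array([[0.1, 0.5],
                 [0.2, 0.6],
                 [0.9, 0.3]])
assignment, total = one_to_one_match(cost, background_cost=1.0)
print(assignment, total)
```

In this sketch exactly N - M predictions always end up in the background, so `background_cost` only shifts the total by a constant and never changes which assignment is optimal.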

Loss from SSD solved with Unbalanced Optimal Transport:

from torch.nn import L1Loss, CrossEntropyLoss

from uotod.match import UnbalancedSinkhorn
from uotod.loss import DetectionLoss, IoULoss

matching_method = UnbalancedSinkhorn(
    cls_match_module=None,  # No classification cost
    loc_match_module=IoULoss(reduction="none"),
    background_cost=0.5,  # Threshold for matching to background
    is_anchor_based=True,  # Use anchor-based matching
    reg_target=1e-3,  # Relax the constraint that each ground truth is matched to exactly one prediction
)

loss = DetectionLoss(
    matching_method=matching_method,
    cls_loss_module=CrossEntropyLoss(reduction="none"),
    loc_loss_module=L1Loss(reduction="none"),
)

preds = ...
targets = ...
anchors = ...

loss_value = loss(preds, targets, anchors)
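In this configuration, `background_cost=0.5` together with an IoU-based cost plays the role of SSD's IoU threshold. As a rough mental model (our assumption about the limiting behaviour, not the package's code): when ground truths may be matched to many anchors, each anchor goes to its lowest-cost ground truth if that cost is below the background cost, and to the background otherwise. With a cost of `1 - IoU` and a background cost of 0.5, this is the familiar rule "match if IoU > 0.5". The helper `threshold_match` is hypothetical:

```python
import numpy as np


def threshold_match(cost, background_cost=0.5):
    """Assign each anchor (row) to its cheapest ground truth (column), or -1 (background).

    Illustrative closest-ground-truth rule: an anchor is matched only if its best
    cost is strictly below `background_cost`.
    """
    best_gt = cost.argmin(axis=1)
    best_cost = cost.min(axis=1)
    return np.where(best_cost < background_cost, best_gt, -1)


iou = np.array([[0.7, 0.1],    # anchor 0 overlaps ground truth 0 well
                [0.3, 0.2],    # anchor 1 overlaps nothing well -> background
                [0.55, 0.6]])  # anchor 2 is slightly closer to ground truth 1
print(threshold_match(1.0 - iou, background_cost=0.5))  # [ 0 -1  1]
```

Note that SSD additionally keeps, for every ground truth, its single best anchor even when it falls below the threshold; this sketch omits that step.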

Loss from DETR solved with Optimal Transport (num_classes=3):

import torch
from torch.nn import L1Loss, CrossEntropyLoss

from uotod.match import BalancedSinkhorn
from uotod.loss import DetectionLoss
from uotod.loss import MultipleObjectiveLoss, GIoULoss, NegativeProbLoss

matching_method = BalancedSinkhorn(
    cls_match_module=NegativeProbLoss(reduction="none"),
    loc_match_module=MultipleObjectiveLoss(
        losses=[GIoULoss(reduction="none"), L1Loss(reduction="none")],
        weights=[1., 5.],
    ),
    background_cost=0.,  # Does not influence the matching when using balanced OT
)

loss = DetectionLoss(
    matching_method=matching_method,
    cls_loss_module=CrossEntropyLoss(
        reduction="none",
        weight=torch.tensor([0.1, 1., 1.])  # down-weight the loss for the no-object class
    ),
    loc_loss_module=MultipleObjectiveLoss(
        losses=[GIoULoss(reduction="none"), L1Loss(reduction="none")],
        weights=[1., 5.],
    ),
)

preds = ...
targets = ...
loss_value = loss(preds, targets)
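The matching cost configured above combines a classification term with a weighted sum of localization terms (weight 1 for GIoU and 5 for L1). Below is a minimal NumPy sketch of such a cost matrix. It assumes that `NegativeProbLoss` contributes minus the predicted probability of the ground-truth class and that the L1 term is the sum of absolute box-coordinate differences; the helper `detr_style_cost` and its argument layout are hypothetical, and the GIoU term is passed in precomputed.

```python
import numpy as np


def detr_style_cost(probs, pred_boxes, gt_labels, gt_boxes, giou_cost, w_giou=1.0, w_l1=5.0):
    """(N, M) matching cost between N predictions and M ground truths.

    probs:      (N, C) predicted class probabilities
    pred_boxes: (N, 4) and gt_boxes: (M, 4) box coordinates
    gt_labels:  (M,) integer class of each ground truth
    giou_cost:  (N, M) precomputed 1 - GIoU matrix
    """
    cls_cost = -probs[:, gt_labels]  # negative probability of each ground truth's class
    l1_cost = np.abs(pred_boxes[:, None, :] - gt_boxes[None, :, :]).sum(axis=-1)  # pairwise L1
    return cls_cost + w_giou * giou_cost + w_l1 * l1_cost


probs = np.array([[0.1, 0.8, 0.1],
                  [0.2, 0.1, 0.7]])
pred_boxes = np.array([[0.1, 0.1, 0.3, 0.3], [0.5, 0.5, 0.9, 0.9]])
gt_boxes = np.array([[0.1, 0.1, 0.3, 0.3], [0.5, 0.5, 0.8, 0.9]])
gt_labels = np.array([1, 2])
giou_cost = np.array([[0.0, 1.5], [1.6, 0.1]])  # toy values for the GIoU term

C = detr_style_cost(probs, pred_boxes, gt_labels, gt_boxes, giou_cost)
print(C.argmin(axis=1))  # cheapest ground truth for each prediction
```

Either a Hungarian solver or balanced Sinkhorn can then operate on a matrix of this kind; only the way the assignment is solved differs.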

Color Boxes Dataset

Examples from the Color Boxes Dataset

Citation

If you use our work, please cite it as:

@InProceedings{De_Plaen_2023_CVPR,
    author    = {De Plaen, Henri and De Plaen, Pierre-Fran\c{c}ois and Suykens, Johan A. K. and Proesmans, Marc and Tuytelaars, Tinne and Van Gool, Luc},
    title     = {Unbalanced Optimal Transport: A Unified Framework for Object Detection},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {3198-3207}
}

Acknowledgements

EU: The research leading to these results has received funding from the European Research Council under the European Union’s Horizon 2020 research and innovation program / ERC Advanced Grant E-DUALITY (787960). This paper reflects only the authors’ views and the Union is not liable for any use that may be made of the contained information. Research Council KUL: Optimization frameworks for deep kernel machines C14/18/068. Flemish Government: FWO: projects: GOA4917N (Deep Restricted Kernel Machines: Methods and Foundations), PhD/Postdoc grant. This research received funding from the Flemish Government (AI Research Program). All the authors are also affiliated with Leuven.AI - KU Leuven institute for AI, B-3000, Leuven, Belgium.


