An extension of the torchmetrics package.
TorchMetrics Extension
Installation
Simple installation from PyPI:
```bash
pip install torchmetrics-ext
```
What is TorchMetrics Extension?
TorchMetrics Extension extends torchmetrics with additional metrics for machine learning tasks. It offers:
- A standardized interface that increases reproducibility
- Reduced boilerplate
- Compatibility with distributed training
- Rigorous testing
- Automatic accumulation over batches
- Automatic synchronization across multiple devices
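Accumulation and synchronization follow torchmetrics' design: a metric keeps running states, updates them once per batch, and computes the final value once at the end. The pure-Python sketch below only illustrates why accumulating counts matters; the class RunningAccuracy is hypothetical and is not part of torchmetrics or this package. In torchmetrics itself, such states are registered with add_state, whose dist_reduce_fx argument tells the library how to combine them across devices.

```python
class RunningAccuracy:
    """Hypothetical illustration of batch accumulation (not a library class)."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Metric "states": running sums kept across batches.
        self.correct = 0
        self.total = 0

    def update(self, preds: list[int], targets: list[int]) -> None:
        # Accumulate counts rather than storing per-batch accuracies.
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self) -> float:
        # One division over everything seen since the last reset.
        return self.correct / self.total if self.total else 0.0


metric = RunningAccuracy()
metric.update([1, 0, 1], [1, 1, 1])  # batch of 3 with 2 correct
metric.update([0], [0])              # batch of 1 with 1 correct
print(metric.compute())
```

Here compute() returns 3 / 4 = 0.75, whereas averaging the two per-batch accuracies would give (2/3 + 1) / 2, about 0.83. Summed states also make cross-device synchronization a simple sum.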
Currently, it offers metrics for:
- 3D Visual Grounding
- 3D Object Detection
- ScanNet (Under development)
Using TorchMetrics Extension
The following examples show how to use the metrics in TorchMetrics Extension:
ScanRefer
Please download the ScanRefer dataset first; the evaluator requires it, and the path to its JSON file is passed as dataset_file_path.
It measures the thresholded accuracy Acc@kIoU, where a prediction counts as positive if its intersection over union (IoU) with the ground-truth box is higher than the threshold k. The metric is based on the ScanRefer task.
```python
import torch
from torchmetrics_ext.metrics.visual_grounding import ScanReferMetric

metric = ScanReferMetric(dataset_file_path="./ScanRefer_filtered_val.json", split="validation")
# preds is a dictionary mapping each unique description identifier (formatted as "{scene_id}_{object_id}_{ann_id}")
# to the predicted axis-aligned bounding box of shape (2, 3)
preds = {
    "scene0011_00_0_0": torch.tensor([[0., 0., 0.], [0.5, 0.5, 0.5]]),
    "scene0011_01_0_1": torch.tensor([[0., 0., 0.], [1., 1., 1.]]),
}
results = metric(preds)
```
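To make Acc@kIoU concrete, here is a minimal pure-Python sketch, not the package's implementation. It assumes each (2, 3) box is [[x_min, y_min, z_min], [x_max, y_max, z_max]], and the ground-truth boxes are hypothetical values chosen for illustration. ScanRefer results are commonly reported at k = 0.25 and k = 0.5.

```python
# Sketch of Acc@kIoU for axis-aligned 3D boxes. Assumption: each box is
# [[x_min, y_min, z_min], [x_max, y_max, z_max]]. Not the package's own code.
Box = list[list[float]]


def box_volume(box: Box) -> float:
    (x0, y0, z0), (x1, y1, z1) = box
    # Clamp negative extents to zero so non-overlapping boxes have volume 0.
    return max(x1 - x0, 0.0) * max(y1 - y0, 0.0) * max(z1 - z0, 0.0)


def iou_3d(a: Box, b: Box) -> float:
    # The intersection spans the larger min corner to the smaller max corner.
    lo = [max(a[0][i], b[0][i]) for i in range(3)]
    hi = [min(a[1][i], b[1][i]) for i in range(3)]
    inter = box_volume([lo, hi])
    union = box_volume(a) + box_volume(b) - inter
    return inter / union if union > 0 else 0.0


def acc_at_k_iou(preds: dict[str, Box], gts: dict[str, Box], k: float) -> float:
    # A description is a hit when its predicted box has IoU higher than k with
    # the ground truth; descriptions without a prediction count as misses.
    hits = sum(key in preds and iou_3d(preds[key], gts[key]) > k for key in gts)
    return hits / len(gts) if gts else 0.0


# Hypothetical ground-truth boxes for the two descriptions in the example above.
gts = {
    "scene0011_00_0_0": [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],
    "scene0011_01_0_1": [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],
}
preds = {
    "scene0011_00_0_0": [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]],  # IoU 0.125
    "scene0011_01_0_1": [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],  # IoU 1.0
}
for k in (0.25, 0.5):
    print(f"Acc@{k}IoU = {acc_at_k_iou(preds, gts, k)}")
```

Whether ties at exactly IoU = k count as hits in the package is not stated here; the sketch follows the "higher than" wording.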
Nr3D
The dataset will be automatically downloaded from the official Nr3D Google Drive.
It measures the accuracy of selecting the target object from the candidates. The metric is based on the Nr3D task.
```python
import torch
from torchmetrics_ext.metrics.visual_grounding import Nr3DMetric

metric = Nr3DMetric(split="test")
# indices of predicted and ground truth objects (B, )
pred_indices = torch.tensor([5, 2, 0, 0], dtype=torch.uint8)
gt_indices = torch.tensor([5, 5, 1, 0], dtype=torch.uint8)
# per-sample evaluation tags: (difficulty, view dependency)
gt_eval_types = (("easy", "view_dep"), ("easy", "view_indep"), ("hard", "view_dep"), ("hard", "view_dep"))
results = metric(pred_indices, gt_indices, gt_eval_types)
```
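The evaluation tags let accuracy be reported overall and per subset (easy/hard, view-dependent/view-independent). The sketch below shows that grouping on the same values as plain lists; the function grouped_accuracy is hypothetical, and the actual Nr3DMetric output keys and format may differ.

```python
from collections import defaultdict


def grouped_accuracy(pred_indices, gt_indices, eval_types):
    """Overall accuracy plus accuracy per evaluation tag (sketch only)."""
    hits = defaultdict(int)
    counts = defaultdict(int)
    for pred, gt, tags in zip(pred_indices, gt_indices, eval_types):
        # int() accepts plain ints as well as 0-d tensor elements.
        correct = int(pred) == int(gt)
        for tag in ("overall", *tags):
            hits[tag] += correct
            counts[tag] += 1
    return {tag: hits[tag] / counts[tag] for tag in counts}


# Same values as the Nr3DMetric example above, as plain lists.
pred_indices = [5, 2, 0, 0]
gt_indices = [5, 5, 1, 0]
gt_eval_types = (("easy", "view_dep"), ("easy", "view_indep"), ("hard", "view_dep"), ("hard", "view_dep"))
print(grouped_accuracy(pred_indices, gt_indices, gt_eval_types))
```

For these inputs the sketch gives 0.5 overall, 0.5 on both easy and hard, about 0.67 on view_dep and 0.0 on view_indep, since only the first and last samples are correct.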
Multi3DRefer
The dataset will be automatically downloaded from the official Multi3DRefer Hugging Face repo.
It measures F1 scores at multiple IoU thresholds (F1@kIoU), where a predicted box counts as positive if its intersection over union (IoU) with a ground-truth box is higher than the threshold k. The metric is based on the Multi3DRefer task.
```python
import torch
from torchmetrics_ext.metrics.visual_grounding import Multi3DReferMetric

metric = Multi3DReferMetric(split="validation")
# preds is a dictionary mapping each unique description identifier (formatted as "{scene_id}_{ann_id}")
# to a variable number of predicted axis-aligned bounding boxes of shape (N, 2, 3)
preds = {
    "scene0011_00_0": torch.tensor([[[0., 0., 0.], [0.5, 0.5, 0.5]]]),  # 1 predicted box
    "scene0011_01_1": torch.tensor([[[0., 0., 0.], [1., 1., 1.]], [[0., 0., 0.], [2., 2., 2.]]]),  # 2 predicted boxes
}
result = metric(preds)
```
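Because a description may refer to zero, one or several objects, predicted and ground-truth boxes must be matched one-to-one before counting true positives. The sketch below illustrates one way to compute a per-description F1@kIoU; it is not the package's implementation. Assumptions: boxes are [[x_min, y_min, z_min], [x_max, y_max, z_max]], matching maximizes total IoU, a zero-target description scores 1 only when nothing is predicted, and the ground-truth box is hypothetical.

```python
from itertools import permutations

# Assumption: boxes are [[x_min, y_min, z_min], [x_max, y_max, z_max]].
Box = list[list[float]]


def volume(lo: list[float], hi: list[float]) -> float:
    # Negative extents are clamped to zero (no overlap).
    return max(hi[0] - lo[0], 0.0) * max(hi[1] - lo[1], 0.0) * max(hi[2] - lo[2], 0.0)


def iou_3d(a: Box, b: Box) -> float:
    lo = [max(a[0][i], b[0][i]) for i in range(3)]
    hi = [min(a[1][i], b[1][i]) for i in range(3)]
    inter = volume(lo, hi)
    union = volume(*a) + volume(*b) - inter
    return inter / union if union > 0 else 0.0


def f1_at_k_iou(pred_boxes: list[Box], gt_boxes: list[Box], k: float) -> float:
    # Assumed zero-target rule: F1 is 1 only when nothing is predicted either.
    if not gt_boxes or not pred_boxes:
        return 1.0 if not gt_boxes and not pred_boxes else 0.0
    # One-to-one matching that maximizes total IoU, found by brute force.
    # Only practical for a few boxes; a Hungarian solver such as
    # scipy.optimize.linear_sum_assignment scales to larger sets.
    small, large = sorted((pred_boxes, gt_boxes), key=len)
    best_ious: list[float] = []
    for perm in permutations(range(len(large)), len(small)):
        ious = [iou_3d(small[i], large[j]) for i, j in enumerate(perm)]
        if sum(ious) > sum(best_ious):
            best_ious = ious
    # Matched pairs with IoU higher than k are true positives.
    tp = sum(iou > k for iou in best_ious)
    if tp == 0:
        return 0.0
    precision = tp / len(pred_boxes)
    recall = tp / len(gt_boxes)
    return 2 * precision * recall / (precision + recall)


# The two predicted boxes for "scene0011_01_1" above, against one hypothetical target.
pred_example = [[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], [[0.0, 0.0, 0.0], [2.0, 2.0, 2.0]]]
gt_example = [[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]]
print(f1_at_k_iou(pred_example, gt_example, 0.5))
```

Here the first box matches the target exactly, so precision is 1/2 and recall is 1, giving F1 = 2/3. A dataset-level score would then average the per-description values, which is also an assumption about how the package aggregates.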
File details
Details for the file torchmetrics_ext-0.3.1.tar.gz.
File metadata
- Download URL: torchmetrics_ext-0.3.1.tar.gz
- Upload date:
- Size: 17.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.24
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | fe1153e372867af86e3a6b139f9120063fcc7772ee64944c2c4714420ae414db |
| MD5 | ab59aed47ab41722a419a3013c7e5fee |
| BLAKE2b-256 | f530aea45523c1d5f1d9544833c456ab3b5d66eedc7642c632c64e95227f4b34 |
File details
Details for the file torchmetrics_ext-0.3.1-py3-none-any.whl.
File metadata
- Download URL: torchmetrics_ext-0.3.1-py3-none-any.whl
- Upload date:
- Size: 21.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.24
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e8da135b545a7c0303271515a20dc5176c354d7811ca7eae978736270b37849f |
| MD5 | 3a879924b33321144b5d4e623a317515 |
| BLAKE2b-256 | 29722299ace987a902b2b0f5c7a354318999f798eeb62ad5b0a6afc78057fd7b |