An extension of the torchmetrics package.

TorchMetrics Extension

Installation

Simple installation from PyPI

pip install torchmetrics-ext

What is TorchMetrics Extension

TorchMetrics Extension builds on torchmetrics and adds further metrics for machine learning tasks. It offers:

  • A standardized interface to increase reproducibility
  • Reduced boilerplate
  • Distributed-training compatible
  • Rigorously tested
  • Automatic accumulation over batches
  • Automatic synchronization between multiple devices

Currently, it offers metrics for:

  • ScanRefer
  • Nr3D
  • Multi3DRefer
  • ReVSI

Using TorchMetrics Extension

Here are examples for using the metrics in TorchMetrics Extension:

ScanRefer

Download the ScanRefer dataset first; the evaluator requires it.

The metric measures thresholded accuracy (Acc@kIoU): a prediction counts as positive when its intersection over union (IoU) with the ground-truth box is higher than the threshold k. The metric is based on the ScanRefer task.

import torch
from torchmetrics_ext.metrics.visual_grounding import ScanReferMetric
metric = ScanReferMetric(dataset_file_path="./ScanRefer_filtered_val.json", split="validation")

# preds is a dictionary mapping each unique description identifier (formatted as "{scene_id}_{object_id}_{ann_id}")
# to the predicted axis-aligned bounding boxes in shape (2, 3)
preds = {
    "scene0011_00_0_0": torch.tensor([[0., 0., 0.], [0.5, 0.5, 0.5]]),
    "scene0011_01_0_1": torch.tensor([[0., 0., 0.], [1., 1., 1.]]),
    ...
}
metric(preds)
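
To make Acc@kIoU concrete, the sketch below computes the IoU of two axis-aligned boxes and the fraction of predictions above a threshold. It is an illustration, not the package's implementation: the helper names `aabb_iou` and `accuracy_at_iou`, the ground-truth boxes, the thresholds 0.25 and 0.5, and the reading of a (2, 3) box as (min corner, max corner) are all assumptions.

```python
import torch


def aabb_iou(box_a, box_b):
    """IoU of two axis-aligned boxes of shape (2, 3): row 0 = min corner, row 1 = max corner (assumed)."""
    inter_min = torch.maximum(box_a[0], box_b[0])
    inter_max = torch.minimum(box_a[1], box_b[1])
    # Clamp to zero so that non-overlapping boxes give an empty intersection.
    inter = torch.clamp(inter_max - inter_min, min=0).prod()
    vol_a = (box_a[1] - box_a[0]).prod()
    vol_b = (box_b[1] - box_b[0]).prod()
    return (inter / (vol_a + vol_b - inter)).item()


def accuracy_at_iou(preds, gts, threshold):
    """Fraction of descriptions whose predicted box has IoU above the threshold."""
    # Strict comparison follows the wording above; the real evaluator may use >=.
    hits = sum(aabb_iou(preds[key], gts[key]) > threshold for key in gts)
    return hits / len(gts)


# Made-up ground-truth boxes for illustration only.
gts = {
    "scene0011_00_0_0": torch.tensor([[0., 0., 0.], [1., 1., 1.]]),
    "scene0011_01_0_1": torch.tensor([[0., 0., 0.], [1., 1., 1.]]),
}
preds = {
    "scene0011_00_0_0": torch.tensor([[0., 0., 0.], [0.5, 0.5, 0.5]]),  # IoU = 0.125
    "scene0011_01_0_1": torch.tensor([[0., 0., 0.], [1., 1., 1.]]),     # IoU = 1.0
}
scores = {k: accuracy_at_iou(preds, gts, k) for k in (0.25, 0.5)}
```

With these boxes only the second prediction clears either threshold, so both scores are 0.5.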

Nr3D

The dataset will be automatically downloaded from the official Nr3D Google Drive.

It measures the accuracy of selecting the target object from the candidates. The metric is based on the Nr3D task.

import torch
from torchmetrics_ext.metrics.visual_grounding import Nr3DMetric

metric = Nr3DMetric(split="test")

# indices of predicted and ground truth objects (B, )
pred_indices = torch.tensor([5, 2, 0, 0], dtype=torch.uint8)
gt_indices = torch.tensor([5, 5, 1, 0], dtype=torch.uint8)

gt_eval_types = (("easy", "view_dep"), ("easy", "view_indep"), ("hard", "view_dep"), ("hard", "view_dep"))
results = metric(pred_indices, gt_indices, gt_eval_types)
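
As a rough illustration of what such a metric reports, the sketch below computes overall accuracy and a per-type breakdown from the same inputs. The result keys and the grouping logic are assumptions for illustration; the structure returned by `Nr3DMetric` may differ.

```python
from collections import defaultdict

import torch

pred_indices = torch.tensor([5, 2, 0, 0], dtype=torch.uint8)
gt_indices = torch.tensor([5, 5, 1, 0], dtype=torch.uint8)
gt_eval_types = (("easy", "view_dep"), ("easy", "view_indep"), ("hard", "view_dep"), ("hard", "view_dep"))

# A prediction is correct when it selects the ground-truth object.
correct = pred_indices == gt_indices

hits = defaultdict(int)
counts = defaultdict(int)
for is_correct, eval_types in zip(correct.tolist(), gt_eval_types):
    # Each sample contributes to every evaluation type it is tagged with.
    for eval_type in eval_types:
        hits[eval_type] += int(is_correct)
        counts[eval_type] += 1

accuracy = {eval_type: hits[eval_type] / counts[eval_type] for eval_type in counts}
accuracy["overall"] = correct.float().mean().item()
```

Here samples 0 and 3 are correct, so the overall accuracy is 0.5, "view_dep" scores 2 of 3, and "view_indep" scores 0 of 1.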

Multi3DRefer

The dataset will be automatically downloaded from the official Multi3DRefer Hugging Face repo.

The metric measures F1 scores at multiple IoU thresholds (F1@kIoU): a predicted box counts as positive when its intersection over union (IoU) with a ground-truth box is higher than the threshold. The metric is based on the Multi3DRefer task.

import torch
from torchmetrics_ext.metrics.visual_grounding import Multi3DReferMetric
metric = Multi3DReferMetric(split="validation")

# preds is a dictionary mapping each unique description identifier (formatted as "{scene_id}_{ann_id}")
# to a variable number of predicted axis-aligned bounding boxes in shape (N, 2, 3)
preds = {
    "scene0011_00_0": torch.tensor([[[0., 0., 0.], [0.5, 0.5, 0.5]]]),  # 1 predicted box
    "scene0011_01_1": torch.tensor([[[0., 0., 0.], [1., 1., 1.]], [[0., 0., 0.], [2., 2., 2.]]])  # 2 predicted boxes
    ...
}
result = metric(preds)
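
The sketch below shows one way an F1@kIoU score can be computed for a single description with a variable number of boxes. It is a simplification under stated assumptions: boxes are (min corner, max corner), predictions are matched to ground truths greedily and one-to-one, both sets are non-empty, and the helper names `pairwise_iou` and `f1_at_iou` are hypothetical. The matching strategy used by `Multi3DReferMetric` may differ.

```python
import torch


def pairwise_iou(preds, gts):
    """IoU between every predicted box (N, 2, 3) and every ground-truth box (M, 2, 3) -> (N, M)."""
    inter_min = torch.maximum(preds[:, None, 0], gts[None, :, 0])
    inter_max = torch.minimum(preds[:, None, 1], gts[None, :, 1])
    inter = torch.clamp(inter_max - inter_min, min=0).prod(dim=-1)
    vol_p = (preds[:, 1] - preds[:, 0]).prod(dim=-1)
    vol_g = (gts[:, 1] - gts[:, 0]).prod(dim=-1)
    return inter / (vol_p[:, None] + vol_g[None, :] - inter)


def f1_at_iou(preds, gts, threshold):
    """F1 for one description, using greedy one-to-one matching (a simplification)."""
    iou = pairwise_iou(preds, gts)
    matched = set()
    true_positives = 0
    for i in range(iou.shape[0]):
        # Pick the best still-unmatched ground truth with IoU above the threshold.
        best_j, best_iou = -1, threshold
        for j in range(iou.shape[1]):
            value = iou[i, j].item()
            if j not in matched and value > best_iou:
                best_j, best_iou = j, value
        if best_j >= 0:
            matched.add(best_j)
            true_positives += 1
    # F1 = 2 * TP / (number of predictions + number of ground truths)
    return 2 * true_positives / (iou.shape[0] + iou.shape[1])


# Two predicted boxes against one made-up ground-truth box.
pred_boxes = torch.tensor([[[0., 0., 0.], [1., 1., 1.]], [[0., 0., 0.], [2., 2., 2.]]])
gt_boxes = torch.tensor([[[0., 0., 0.], [1., 1., 1.]]])
f1 = f1_at_iou(pred_boxes, gt_boxes, threshold=0.5)
```

The first prediction matches the ground truth exactly and the second is a false positive, so precision is 1/2, recall is 1, and F1 is 2/3.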

ReVSI

The dataset will be automatically downloaded from the official ReVSI Hugging Face repo.

import torch
from torchmetrics_ext.metrics.vqa import ReVSIMetric

metric = ReVSIMetric(subset="all_frame")

# preds is a dictionary mapping each unique question identifier "id" to a predicted answer
preds = {
    "1": "3"
    "100": "A",
    ...
}
result = metric(preds)
