Dropwise-Metrics

TorchMetrics-compatible predictive uncertainty metrics using MC Dropout

Dropwise-Metrics is a lightweight, TorchMetrics-compatible toolkit for Monte Carlo Dropout–based uncertainty estimation in Transformers. It enables confidence-aware decision making by revealing how certain a model is about its predictions, packaged as plug-and-play Metric classes.


Features

  • Enable dropout during inference for Bayesian-like uncertainty estimation
  • Compute predictive entropy, confidence, and per-class standard deviation
  • Modular support for classification, QA, token tagging, and regression
  • Works seamlessly with Hugging Face Transformers and PyTorch
  • TorchMetrics-compatible: .update() + .compute()
  • Supports batch inference, CPU/GPU, and customizable num_passes
  • Cleanly packaged and extensible for research or production
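The statistics behind these features are straightforward: run several stochastic forward passes with dropout active, average the resulting softmax vectors, and summarize the spread. A minimal pure-Python sketch of the math (an illustration, not the library's internals):

```python
import math

def predictive_stats(prob_samples):
    """Summarize T stochastic softmax samples for one input.

    prob_samples: list of T probability vectors, one per MC-Dropout pass.
    Returns mean probabilities, predictive entropy, and per-class std dev.
    """
    T = len(prob_samples)
    C = len(prob_samples[0])
    # Mean probability per class across the T passes.
    mean = [sum(s[c] for s in prob_samples) / T for c in range(C)]
    # Predictive entropy of the averaged distribution (lower = more confident).
    entropy = -sum(p * math.log(p) for p in mean if p > 0)
    # Per-class standard deviation across passes (disagreement between samples).
    std = [
        math.sqrt(sum((s[c] - mean[c]) ** 2 for s in prob_samples) / T)
        for c in range(C)
    ]
    return mean, entropy, std

# Three simulated dropout passes over a 2-class problem:
samples = [[0.90, 0.10], [0.80, 0.20], [0.85, 0.15]]
mean, entropy, std = predictive_stats(samples)  # mean -> [0.85, 0.15]
```

A confident model produces nearly identical vectors on every pass (low entropy, low std); an uncertain one produces samples that disagree.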

Supported Tasks

  • sequence-classification — e.g. distilbert-base-uncased-finetuned-sst-2-english
  • token-classification — e.g. dslim/bert-base-NER
  • question-answering — e.g. deepset/bert-base-cased-squad2
  • regression — e.g. roberta-base with a custom head

Note: Your model must contain dropout layers for MC sampling to work (most Hugging Face models do).
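You can verify this for your own model by scanning for `nn.Dropout` modules. The usual MC-Dropout trick is to call `model.eval()` and then switch only the dropout layers back into train mode, so sampling stays stochastic while batch norm and friends stay deterministic (a generic PyTorch sketch, not the dropwise-metrics API):

```python
import torch.nn as nn

def enable_mc_dropout(model: nn.Module) -> int:
    """Put dropout layers in train mode (stochastic) while the rest of
    the model stays in eval mode. Returns the number of dropout layers
    found, so you can confirm MC sampling will actually vary."""
    model.eval()
    n = 0
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
            n += 1
    return n

# Toy model: one hidden layer with dropout.
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),
                    nn.Dropout(p=0.5), nn.Linear(8, 2))
num_dropout = enable_mc_dropout(toy)  # finds 1 dropout layer
```

If the count is zero, every forward pass is identical and the uncertainty estimates collapse to point predictions.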


Installation

pip install dropwise-metrics

Or install from source:

git clone https://github.com/aryanator/dropwise-metrics.git
cd dropwise-metrics
pip install -e .

Example Usage (Metric Style)

Sequence Classification

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from dropwise_metrics.metrics.entropy import PredictiveEntropyMetric

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

metric = PredictiveEntropyMetric(model, tokenizer, task_type="sequence-classification", num_passes=20)
metric.update(["The movie was fantastic!", "Awful experience."])
results = metric.compute()

print(results[0])

Token Classification (NER)

from transformers import AutoModelForTokenClassification, AutoTokenizer
from dropwise_metrics.metrics.entropy import PredictiveEntropyMetric

model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")

metric = PredictiveEntropyMetric(model, tokenizer, task_type="token-classification", num_passes=15)
metric.update(["Hugging Face is based in New York City."])
results = metric.compute()

print(results[0]['token_predictions'])

Question Answering

from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from dropwise_metrics.metrics.entropy import PredictiveEntropyMetric

model = AutoModelForQuestionAnswering.from_pretrained("deepset/bert-base-cased-squad2")
tokenizer = AutoTokenizer.from_pretrained("deepset/bert-base-cased-squad2")

question = "Where is Hugging Face based?"
context = "Hugging Face Inc. is a company based in New York City."
qa_input = f"{question} [SEP] {context}"

metric = PredictiveEntropyMetric(model, tokenizer, task_type="question-answering", num_passes=10)
metric.update([qa_input])
results = metric.compute()

print(results[0]['answer'])

Regression

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from dropwise_metrics.metrics.entropy import PredictiveEntropyMetric

model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

metric = PredictiveEntropyMetric(model, tokenizer, task_type="regression", num_passes=20)
metric.update(["The child is very young."])
results = metric.compute()

print(results[0]['predicted_score'], "+/-", results[0]['uncertainty'])

Output Dictionary (per sample)

Common fields returned:

  • predicted_class: Most probable class (classification)
  • predicted_score: Scalar prediction (regression)
  • confidence: Highest softmax probability
  • entropy: Predictive entropy (lower = more confident)
  • std_dev: Per-class standard deviation
  • probs: Raw softmax probabilities
  • margin: Confidence gap between top-2 classes
  • answer: Predicted span (question answering only)
  • token_predictions: Per-token predictions (NER only)
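Given the `probs` field, several of the other fields follow directly; for instance, `confidence` is the top probability and `margin` is the gap to the runner-up. An illustrative calculation (not library code):

```python
probs = [0.70, 0.20, 0.10]  # example mean softmax output over 3 classes

ranked = sorted(probs, reverse=True)
predicted_class = probs.index(ranked[0])  # index of the top class
confidence = ranked[0]                    # highest softmax probability
margin = ranked[0] - ranked[1]            # small margin = ambiguous prediction

print(predicted_class, confidence, margin)  # 0 0.7 0.5
```

A prediction can have high confidence but a small margin when two classes compete, which is why both fields are reported.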

Run Tests

python test_entropy.py

Folder Structure

dropwise_metrics/
├── base.py
├── metrics/
│   └── entropy.py
├── tasks/
│   ├── __init__.py
│   ├── sequence_classification.py
│   ├── token_classification.py
│   ├── question_answering.py
│   └── regression.py

License

MIT License

Built with ❤️ for robust, explainable, uncertainty-aware AI systems.

