
TorchMetrics-compatible predictive uncertainty metrics using MC Dropout


Dropwise-Metrics

Dropwise-Metrics is a lightweight, TorchMetrics-compatible toolkit for Monte Carlo Dropout-based uncertainty estimation in Transformer models. It enables confidence-aware decision-making by revealing how certain a model is about its predictions, packaged as plug-and-play PyTorch Metric classes.


Features

  • Enable dropout during inference for Bayesian-like uncertainty estimation
  • Compute predictive entropy, confidence, and per-class standard deviation
  • Modular support for classification, QA, token tagging, and regression
  • Works seamlessly with Hugging Face Transformers and PyTorch
  • TorchMetrics-compatible: .update() + .compute()
  • Supports batch inference, CPU/GPU, and customizable num_passes
  • Cleanly packaged and extensible for research or production
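The core idea behind these metrics can be sketched in plain PyTorch. The snippet below uses a toy classifier rather than the package's actual implementation: dropout is kept active at inference time, several stochastic forward passes are averaged, and predictive entropy plus per-class standard deviation are computed from the sampled probabilities.

```python
import torch
import torch.nn as nn

# Toy classifier with a dropout layer (stand-in for a Transformer head).
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 3))

def mc_dropout_predict(model, x, num_passes=20):
    """Run num_passes stochastic forward passes with dropout enabled."""
    model.train()  # keep dropout active at inference time
    with torch.no_grad():
        probs = torch.stack(
            [torch.softmax(model(x), dim=-1) for _ in range(num_passes)]
        )  # shape: (num_passes, batch, num_classes)
    mean_probs = probs.mean(dim=0)  # averaged predictive distribution
    entropy = -(mean_probs * mean_probs.clamp_min(1e-12).log()).sum(dim=-1)
    std_dev = probs.std(dim=0)      # per-class standard deviation across passes
    return mean_probs, entropy, std_dev

x = torch.randn(2, 8)
mean_probs, entropy, std_dev = mc_dropout_predict(model, x)
```

Higher entropy and larger per-class standard deviation both indicate that the stochastic passes disagree, i.e. the model is less certain about that input.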

Supported Tasks

  • sequence-classification — e.g. distilbert-base-uncased-finetuned-sst-2-english
  • token-classification — e.g. dslim/bert-base-NER
  • question-answering — e.g. deepset/bert-base-cased-squad2
  • regression — e.g. roberta-base with a custom head

Note: Your model must contain dropout layers for MC sampling to work (most Hugging Face models do).
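If you are unsure whether a model qualifies, you can check for dropout layers directly. This helper is illustrative and not part of the package's API; it relies only on the standard `nn.Module.modules()` traversal, which Hugging Face models expose as well.

```python
import torch.nn as nn

def has_dropout(model: nn.Module) -> bool:
    """Return True if any submodule is a dropout layer (required for MC sampling)."""
    return any(isinstance(m, nn.Dropout) for m in model.modules())

# Example with a small model; Hugging Face models are traversed the same way.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1))
print(has_dropout(model))  # True
```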


Installation

pip install dropwise-metrics

Or install from source:

git clone https://github.com/aryanator/dropwise-metrics.git
cd dropwise-metrics
pip install -e .

Example Usage (Metric Style)

Sequence Classification

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from dropwise_metrics.metrics.entropy import PredictiveEntropyMetric

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

metric = PredictiveEntropyMetric(model, tokenizer, task_type="sequence-classification", num_passes=20)
metric.update(["The movie was fantastic!", "Awful experience."])
results = metric.compute()

print(results[0])

Token Classification (NER)

from transformers import AutoModelForTokenClassification, AutoTokenizer
from dropwise_metrics.metrics.entropy import PredictiveEntropyMetric

model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")

metric = PredictiveEntropyMetric(model, tokenizer, task_type="token-classification", num_passes=15)
metric.update(["Hugging Face is based in New York City."])
results = metric.compute()

print(results[0]['token_predictions'])

Question Answering

from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from dropwise_metrics.metrics.entropy import PredictiveEntropyMetric

model = AutoModelForQuestionAnswering.from_pretrained("deepset/bert-base-cased-squad2")
tokenizer = AutoTokenizer.from_pretrained("deepset/bert-base-cased-squad2")

question = "Where is Hugging Face based?"
context = "Hugging Face Inc. is a company based in New York City."
qa_input = f"{question} [SEP] {context}"

metric = PredictiveEntropyMetric(model, tokenizer, task_type="question-answering", num_passes=10)
metric.update([qa_input])
results = metric.compute()

print(results[0]['answer'])

Regression

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from dropwise_metrics.metrics.entropy import PredictiveEntropyMetric

model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

metric = PredictiveEntropyMetric(model, tokenizer, task_type="regression", num_passes=20)
metric.update(["The child is very young."])
results = metric.compute()

print(results[0]['predicted_score'], "+/-", results[0]['uncertainty'])

Output Dictionary (per sample)

Common fields returned:

  • predicted_class: Most probable class (classification)
  • predicted_score: Scalar prediction (regression)
  • confidence: Highest softmax probability
  • entropy: Predictive entropy (lower = more confident)
  • std_dev: Per-class standard deviation
  • probs: Raw softmax probabilities
  • margin: Confidence gap between top-2 classes
  • answer: Predicted span (question answering only)
  • token_predictions: Per-token predictions (NER only)
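Because each sample's result is a plain dictionary, routing decisions on uncertainty is straightforward. The helper below is a hypothetical example (the `entropy` and `predicted_class` field names come from the list above; the threshold value is arbitrary) that splits results into confident and uncertain buckets, e.g. for deferring uncertain samples to human review.

```python
def flag_uncertain(results, entropy_threshold=0.5):
    """Split per-sample result dicts by predictive entropy (lower = more confident)."""
    confident, uncertain = [], []
    for r in results:
        (uncertain if r["entropy"] > entropy_threshold else confident).append(r)
    return confident, uncertain

# Illustrative results, shaped like the per-sample dictionaries described above.
results = [
    {"predicted_class": 1, "confidence": 0.97, "entropy": 0.12},
    {"predicted_class": 0, "confidence": 0.55, "entropy": 0.89},
]
confident, uncertain = flag_uncertain(results)
```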

Run Tests

python test_entropy.py

Folder Structure

dropwise_metrics/
├── base.py
├── metrics/
│   └── entropy.py
├── tasks/
│   ├── __init__.py
│   ├── sequence_classification.py
│   ├── token_classification.py
│   ├── question_answering.py
│   └── regression.py

License

MIT License

Built with ❤️ for robust, explainable, uncertainty-aware AI systems.

