Protected Classification package

License: MIT

Protected classification

This library contains the Python implementation of protected probabilistic classification, a method for protecting probabilistic prediction models against changes in the data distribution, concentrating on the case of classification. This matters in applications of machine learning, where the quality of a trained prediction algorithm may drop significantly during deployment under various forms of dataset shift.

Installation

pip install protected-classification
conda install conda-forge::protected-classification

The algorithm can be applied on top of an underlying scikit-learn algorithm for binary and multiclass classification problems.
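As a rough illustration of what "underlying scikit-learn algorithm" means in the multiclass case, the sketch below fits a plain scikit-learn classifier on three classes and inspects its probability output. It uses scikit-learn and NumPy only; that the protected wrapper consumes this kind of `predict_proba` output is an assumption based on the `estimator=clf` argument shown under Usage, not something confirmed here.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Three-class synthetic problem (same generator as in the usage example)
X, y = make_classification(n_samples=1000, n_classes=3, n_informative=10, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)

# Any scikit-learn classifier exposing predict_proba could stand in here
clf = RandomForestClassifier(random_state=1).fit(X_train, y_train)
p_test = clf.predict_proba(X_test)

# One row per test example, one column per class; each row sums to 1
print(p_test.shape)
print(np.allclose(p_test.sum(axis=1), 1.0))
```

A fitted estimator of this kind is what would be handed to the wrapper in place of the binary classifier used in the usage example.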

Usage

from protected_classification import ProtectedClassification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import log_loss
from sklearn.datasets import make_classification
import numpy as np

np.random.seed(1)

X, y = make_classification(n_samples=1000, n_classes=2, n_informative=10, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
p_test = clf.predict_proba(X_test)

# Initialise Protected classification
pc = ProtectedClassification(estimator=clf)

# Calibrate test output probabilities
pc.fit(X_train, y_train)
p_prime = pc.predict_proba(X_test)

# Compare log loss of underlying RF algorithm and Protected classification
print(f'Underlying classifier log loss (no dataset shift): {log_loss(y_test, p_test):.3f}')
print(f'Protected classification log loss (no dataset shift): {log_loss(y_test, p_prime):.3f}')

# Assume a dataset shift where a random portion of the class labels is set to a single class
y_test[:100] = 0
ind = np.random.permutation(len(y_test))
X_test = X_test[ind]
y_test = y_test[ind]

p_test = clf.predict_proba(X_test)

# Generate protected output probabilities (assuming that test examples arrive sequentially)
pc = ProtectedClassification(estimator=clf)
p_prime = pc.predict_proba(X_test, y_test)

# Compare log loss of underlying RF algorithm and Protected classification
print(f'Underlying classifier log loss (dataset shift): {log_loss(y_test, p_test):.3f}')
print(f'Protected classification log loss (dataset shift): {log_loss(y_test, p_prime):.3f}')
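The label-shift step in the usage example can be isolated to see how much it hurts the unprotected classifier on its own. The following is a minimal sketch using scikit-learn and NumPy only; `simulate_label_shift` is a hypothetical helper written for this illustration and is not part of the package.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split


def simulate_label_shift(X, y, n_overwrite, target_class=0, seed=1):
    """Hypothetical helper: set the first n_overwrite labels to one class, then shuffle."""
    rng = np.random.default_rng(seed)
    y_shifted = y.copy()  # leave the caller's labels untouched
    y_shifted[:n_overwrite] = target_class
    order = rng.permutation(len(y_shifted))
    return X[order], y_shifted[order]


X, y = make_classification(n_samples=1000, n_classes=2, n_informative=10, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)
clf = RandomForestClassifier(random_state=1).fit(X_train, y_train)

# Log loss of the underlying classifier before and after the simulated shift
loss_clean = log_loss(y_test, clf.predict_proba(X_test))
X_shift, y_shift = simulate_label_shift(X_test, y_test, n_overwrite=100)
loss_shift = log_loss(y_shift, clf.predict_proba(X_shift))

print(f'Baseline log loss (no dataset shift): {loss_clean:.3f}')
print(f'Baseline log loss (dataset shift): {loss_shift:.3f}')
```

Copying the labels inside the helper keeps the original `y_test` available for the no-shift comparison, whereas the usage example overwrites it in place.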

Examples

Further examples can be found in the examples folder of the GitHub repository: https://github.com/ip200/protected-calibration

Citation

If you find this library useful, please consider citing:

  • Vovk, Vladimir, Ivan Petej, and Alex Gammerman. "Protected probabilistic classification." In Conformal and Probabilistic Prediction and Applications, pp. 297-299. PMLR, 2021. (arXiv version: https://arxiv.org/pdf/2107.01726.pdf)

