The river reliability diagram provides insight into classifier calibration with a visualization and a metric based on the posterior balanced accuracy.
River reliability
Install
Install the package with:
pip install riverreliability
How to use
Below, we show some basic functionality of the package. Please look at the notebooks for more examples and documentation.
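import numpy as np
import sklearn.datasets, sklearn.metrics, sklearn.model_selection, sklearn.svm
import riverreliability.metrics, riverreliability.plots
np.random.seed(42)  # fix the seed so the generated data and the split are reproducible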
We start off by generating a fake dataset for classification and splitting it into a train and a test set.
X, y = sklearn.datasets.make_classification(n_samples=5000, n_features=12, n_informative=3, n_classes=3)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(X, y, test_size=0.2, shuffle=True)
For this example, we use an SVM. We fit it on the training data and generate class probabilities for the test set. Note that `SVC` only exposes `predict_proba` when it is constructed with `probability=True`.
svm_clf = sklearn.svm.SVC(probability=True)
svm_clf.fit(X_train, y_train)
y_probs = svm_clf.predict_proba(X_test)
As a sanity check, we compute some performance metrics.
print(f"Accuracy: {sklearn.metrics.accuracy_score(y_test, y_probs.argmax(axis=1))}")
print(f"Balanced accuracy: {sklearn.metrics.balanced_accuracy_score(y_test, y_probs.argmax(axis=1))}")
Accuracy: 0.808
Balanced accuracy: 0.8084048918146675
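To make the difference between the two numbers above concrete, here is a minimal sketch of both metrics in plain Python. Balanced accuracy is the unweighted mean of the per-class recall, so it only departs from plain accuracy when the classes are imbalanced; the two values above nearly coincide, presumably because `make_classification` generates roughly balanced classes by default. The helper names and the toy labels below are hypothetical and are not part of riverreliability or scikit-learn.

```python
from collections import defaultdict


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def balanced_accuracy(y_true, y_pred):
    """Unweighted mean of the per-class recall."""
    hits, totals = defaultdict(int), defaultdict(int)
    for t, p in zip(y_true, y_pred):
        totals[t] += 1
        hits[t] += int(t == p)
    return sum(hits[c] / totals[c] for c in totals) / len(totals)


# Hypothetical, deliberately imbalanced toy labels (not the data used in this example).
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]

print(accuracy(y_true, y_pred))           # 0.9
print(balanced_accuracy(y_true, y_pred))  # 0.75
```

On the toy labels, the minority class is recalled only half the time, which plain accuracy hides and balanced accuracy exposes.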
To gain insight into calibration, we can look at the posterior reliability diagrams and the PEACE metric.
We can plot the diagrams aggregated over all classes:
ax = riverreliability.plots.river_reliability_diagram(y_probs.max(axis=1), y_probs.argmax(axis=1), y_test, bins="fd")
peace_metric = riverreliability.metrics.peace(y_probs.max(axis=1), y_probs.argmax(axis=1), y_test)
ax.set_title(f"PEACE: {peace_metric:.4f}")
_ = ax.legend()
Or class-wise to spot miscalibrations for particular classes:
import matplotlib.pyplot as plt
axes = riverreliability.plots.class_wise_river_reliability_diagram(y_probs, y_probs.argmax(axis=1), y_test, bins=15)
peace_metric = riverreliability.metrics.class_wise_error(y_probs, y_probs.argmax(axis=1), y_test, base_error=riverreliability.metrics.peace)
_ = plt.suptitle(f"PEACE: {peace_metric:.4f}")
In this particular example, the diagrams show that the classifier is well calibrated.
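For a rough numeric cross-check of such a reading, a conventional expected calibration error (ECE) can be computed from the same three arrays that the plotting and metric calls take: the confidence of the predicted class, the predicted class, and the true label. The sketch below is the standard equal-width-bin ECE, not the PEACE metric and not the library's implementation; the function name `expected_calibration_error` and the toy probabilities are assumptions made for illustration.

```python
import numpy as np


def expected_calibration_error(confidences, predictions, labels, n_bins=10):
    """Conventional equal-width-bin ECE (an illustration, NOT the PEACE metric)."""
    confidences = np.asarray(confidences, dtype=float)
    correct = (np.asarray(predictions) == np.asarray(labels)).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Digitize against the interior edges only, so bin indices run 0..n_bins-1.
    bin_ids = np.digitize(confidences, edges[1:-1])
    ece = 0.0
    for b in range(n_bins):
        mask = bin_ids == b
        if mask.any():
            # Gap between observed accuracy and mean confidence in this bin,
            # weighted by the fraction of samples that fall into the bin.
            gap = abs(correct[mask].mean() - confidences[mask].mean())
            ece += mask.mean() * gap
    return ece


# Hypothetical toy probabilities for a two-class problem (not the SVM output above).
toy_probs = np.array([[0.95, 0.05], [0.85, 0.15], [0.25, 0.75], [0.35, 0.65]])
toy_labels = np.array([0, 0, 1, 0])

ece = expected_calibration_error(
    toy_probs.max(axis=1),     # confidence of the predicted class
    toy_probs.argmax(axis=1),  # predicted class
    toy_labels,                # true class
)
print(round(ece, 3))  # 0.275
```

A value close to zero means that confidence and observed accuracy agree bin by bin. Unlike this point estimate, the diagrams above are described as being based on a posterior, so the two are not expected to match exactly.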
See the notebooks directory for more examples.