Supervised multi-class/single-label classification with gradients
Project description
This Python project provides a supervised multi-class classification algorithm with a focus on calibration. It predicts class labels and class probabilities, including the gradients of these probabilities with respect to the features. The classifier follows the design principles of a scikit-learn estimator.
The details of the algorithm have been published in PLOS ONE (preprint: arXiv:2103.02926).
Complete documentation of the code is available at https://casimac.readthedocs.io/en/latest/. Example notebooks can be found in the examples directory.
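In this context, calibration means that predicted probabilities agree with observed frequencies. For example, among samples whose predicted class receives a probability of about 0.8, roughly 80% should actually belong to that class. A common way to quantify this is the expected calibration error (ECE) over equal-width confidence bins. The sketch below is a generic NumPy implementation of that metric, not a function provided by casimac. The helper name `expected_calibration_error`, the bin count, and the demo inputs are illustrative assumptions.

```python
import numpy as np

def expected_calibration_error(y_true, proba, n_bins=10):
    """Top-label expected calibration error (ECE).

    y_true: integer class indices 0..K-1 in the column order of `proba`.
    proba:  array of shape (n_samples, K) with predicted class probabilities.
    Returns the sample-weighted mean of |accuracy - confidence| over
    equal-width confidence bins. Hypothetical helper, not part of casimac.
    """
    y_true = np.asarray(y_true)
    proba = np.asarray(proba, dtype=float)
    confidence = proba.max(axis=1)    # probability assigned to the predicted class
    predicted = proba.argmax(axis=1)  # index of the predicted class
    correct = (predicted == y_true).astype(float)
    # Interior bin edges; np.digitize maps each confidence to a bin index 0..n_bins-1
    inner_edges = np.linspace(0.0, 1.0, n_bins + 1)[1:-1]
    bin_ids = np.digitize(confidence, inner_edges, right=True)
    ece = 0.0
    for b in range(n_bins):
        in_bin = bin_ids == b
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidence[in_bin].mean())
            ece += in_bin.mean() * gap  # weight by the fraction of samples in the bin
    return ece

# Illustrative inputs (assumed, not casimac output)
proba_demo = np.array([[0.9, 0.1]] * 10)
y_demo = np.array([0] * 9 + [1])  # 90% of the samples belong to class 0
print(expected_calibration_error(y_demo, proba_demo))  # approximately 0: well calibrated

proba_over = np.array([[1.0, 0.0]] * 4)
print(expected_calibration_error([0, 0, 1, 1], proba_over))  # 0.5: overconfident
```

With a fitted classifier, the same helper could be applied as `expected_calibration_error(y_test, clf.predict_proba(X_test))` on held-out data, assuming the labels are encoded as 0..K-1 in the column order of the returned probabilities.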
Installation
Install the package via pip or clone this repository. To install with pip, run:
$ pip install casimac
Getting Started
Use the CASIMAClassifier class to create a classifier object. This object provides a fit method for training and a predict method for estimating class labels. In addition, the predict_proba method predicts class probabilities.
Below is a short example. It passes scikit-learn's GaussianProcessRegressor class to the CASIMAClassifier constructor.
from casimac import CASIMAClassifier
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
import matplotlib.pyplot as plt
# Create toy data
N = 10
seed = 42
X = np.random.RandomState(seed).uniform(-10,10,N).reshape(-1,1)
y = np.zeros(X.size)
y[X[:,0]>0] = 1
# Classify
clf = CASIMAClassifier(GaussianProcessRegressor)
clf.fit(X, y)
# Predict
X_sample = np.linspace(-10,10,100).reshape(-1,1)
y_sample = clf.predict(X_sample)
p_sample = clf.predict_proba(X_sample)
# Plot result
plt.figure(figsize=(8,3))
plt.plot(X_sample,y_sample,label="class prediction")
plt.plot(X_sample,p_sample[:,1],label="class probability prediction")
plt.scatter(X,y,c='r',label="train data")
plt.xlabel("X")
plt.ylabel("label / probability")
plt.legend()
plt.show()
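The project description states that gradients of the class probabilities with respect to the features are available. How to request them from CASIMAClassifier is covered in the documentation and is not reproduced here. As an independent cross-check, such gradients can be approximated numerically with central finite differences. The sketch below assumes only a `predict_proba`-style callable that maps an `(n_samples, n_features)` array to an `(n_samples, n_classes)` array and evaluates each row independently. The helper name `finite_difference_proba_grad`, the step size `eps`, and the logistic stand-in model are hypothetical choices, not part of casimac.

```python
import numpy as np

def finite_difference_proba_grad(predict_proba, X, eps=1e-5):
    """Approximate d p_k / d x_j by central differences.

    Returns an array of shape (n_samples, n_classes, n_features).
    Assumes each row's prediction depends only on that row's features.
    Hypothetical helper, not part of the casimac API.
    """
    X = np.asarray(X, dtype=float)
    n_samples, n_features = X.shape
    n_classes = np.asarray(predict_proba(X)).shape[1]
    grad = np.empty((n_samples, n_classes, n_features))
    for j in range(n_features):
        step = np.zeros(n_features)
        step[j] = eps  # perturb only feature j, broadcast over all samples
        p_plus = np.asarray(predict_proba(X + step))
        p_minus = np.asarray(predict_proba(X - step))
        grad[:, :, j] = (p_plus - p_minus) / (2.0 * eps)
    return grad

# Stand-in two-class model (assumption): logistic function of the first feature
def toy_proba(X):
    p1 = 1.0 / (1.0 + np.exp(-X[:, 0]))
    return np.column_stack([1.0 - p1, p1])

X_check = np.array([[0.0], [1.0], [-2.0]])
G = finite_difference_proba_grad(toy_proba, X_check)
print(G.shape)  # (3, 2, 1)
```

For the fitted classifier from the example above, `finite_difference_proba_grad(clf.predict_proba, X_sample)[:, 1, 0]` would approximate the derivative of the class-1 probability with respect to the single feature. The step size may need adjusting to the scale of the features.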
📖 Citation
If you find this code useful, please consider citing:
@article{10.1371/journal.pone.0279876,
doi={10.1371/journal.pone.0279876},
author={Heese, Raoul and Schmid, Jochen and Walczak, Micha{\l} and Bortz, Michael},
journal={PLOS ONE},
publisher={Public Library of Science},
title={Calibrated simplex-mapping classification},
year={2023},
month={01},
volume={18},
url={https://doi.org/10.1371/journal.pone.0279876},
pages={1-26},
number={1}
}
License
This project is licensed under the MIT License - see the LICENSE file for details.
Download files
Source Distribution: casimac-1.2.4.tar.gz
Built Distribution: casimac-1.2.4-py3-none-any.whl
File details
Details for the file casimac-1.2.4.tar.gz
File metadata
- Download URL: casimac-1.2.4.tar.gz
- Upload date:
- Size: 13.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 74bda7e7be5c7731ac3df3386d3d76c3673192b57dccbbc26a6453ece77fcf52
MD5 | 13333853c02b768862ecac5972bd88d6
BLAKE2b-256 | 55160e75a13df17acf7b3a6939cd710e978d9258197b9d023fde3bd839d9edce
File details
Details for the file casimac-1.2.4-py3-none-any.whl
File metadata
- Download URL: casimac-1.2.4-py3-none-any.whl
- Upload date:
- Size: 12.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 61dcdc312d534bd9718a34dc7bbfb274945ea1980fd2da87e3bb78bfd1a22169
MD5 | 108bbc0009d2e177c2a3915db90cf231
BLAKE2b-256 | 00dac83d9c2f6fbf276e95e91a65c51dad3463488933252d983f7839e89bd7a3
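The digests listed for each file can be used to check a download before installing it. A minimal sketch using Python's standard `hashlib` module follows. The local file location (the current directory) and the helper name `sha256_of` are assumptions. The expected value is the SHA256 digest listed for the source distribution casimac-1.2.4.tar.gz.

```python
import hashlib
from pathlib import Path

def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # iter() with a sentinel keeps reading until read() returns b""
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# SHA256 digest listed for the source distribution casimac-1.2.4.tar.gz
EXPECTED_SDIST_SHA256 = "74bda7e7be5c7731ac3df3386d3d76c3673192b57dccbbc26a6453ece77fcf52"

if __name__ == "__main__":
    sdist = Path("casimac-1.2.4.tar.gz")  # assumption: downloaded to the current directory
    if sdist.exists():
        match = sha256_of(sdist) == EXPECTED_SDIST_SHA256
        print("SHA256 matches" if match else "SHA256 MISMATCH: do not install")
    else:
        print(f"{sdist} not found")
```

pip can enforce the same check in hash-checking mode. In a requirements file, a line such as `casimac==1.2.4 --hash=sha256:<digest>` enables it. Both the wheel and the source-distribution digests can be listed as separate `--hash` options, since pip may select either file.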