
Supervised multi-class/single-label classification with gradients

Project description


This Python project provides a supervised multi-class classification algorithm with a focus on calibration. It predicts class labels and class probabilities, together with their gradients with respect to the features. The classifier follows the design principles of a scikit-learn estimator.
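To give a rough intuition for how a probabilistic regressor can yield calibrated class probabilities with feature gradients, consider a simplified binary case. This is a hedged sketch, not the package's implementation (the method in the paper is based on a simplex mapping of the classes). It assumes a one-dimensional latent prediction z(x) ~ N(μ(x), σ(x)²) and assigns class 1 when z > 0, so that p₁(x) = Φ(μ(x)/σ(x)). The functions `mu` and `sigma` are hypothetical toy choices; the analytic gradient follows from the chain rule and is checked against a finite difference.

```python
import math

# Hypothetical toy latent mean and standard deviation for a scalar feature x.
# They stand in for the output of a probabilistic regressor; they are NOT casimac internals.
def mu(x):
    return math.tanh(x)

def dmu(x):
    # d/dx tanh(x) = 1 - tanh(x)^2
    return 1.0 - math.tanh(x) ** 2

def sigma(x):
    return 0.5 + 0.1 * x ** 2

def dsigma(x):
    return 0.2 * x

def norm_cdf(t):
    # Standard normal CDF via the error function.
    return 0.5 * (1.0 + math.erf(t / math.sqrt(2.0)))

def norm_pdf(t):
    return math.exp(-0.5 * t * t) / math.sqrt(2.0 * math.pi)

def prob_class1(x):
    """P(z > 0) for z ~ N(mu(x), sigma(x)^2), i.e. Phi(mu(x) / sigma(x))."""
    return norm_cdf(mu(x) / sigma(x))

def grad_prob_class1(x):
    """Analytic d/dx Phi(mu(x) / sigma(x)) via the chain rule."""
    m, s = mu(x), sigma(x)
    t = m / s
    dt = (dmu(x) * s - m * dsigma(x)) / s ** 2  # quotient rule for t = mu / sigma
    return norm_pdf(t) * dt

if __name__ == "__main__":
    x0, h = 0.3, 1e-6
    # Central finite difference as an independent check of the analytic gradient.
    fd = (prob_class1(x0 + h) - prob_class1(x0 - h)) / (2 * h)
    print(prob_class1(x0), grad_prob_class1(x0), fd)
```

The class-0 probability is simply 1 - p₁(x), so its gradient is the negative of the one above.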

The details of the algorithm have been published in PLOS ONE (preprint: arXiv:2103.02926).

Complete documentation of the code is available at https://casimac.readthedocs.io/en/latest/. Example notebooks can be found in the examples directory of the repository.

Installation

Install the package via pip or clone the repository. To install with pip, run:

$ pip install casimac

Getting Started

Use the CASIMAClassifier class to create a classifier object. This object provides a fit method for training and a predict method for the estimation of class labels. Furthermore, the predict_proba method can be used to predict class label probabilities.
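Under scikit-learn conventions, `predict_proba` returns one row per sample and one column per class, with columns ordered as in the fitted `classes_` attribute, and each row sums to one. The helper below is a hedged sketch that checks these conventions on any fitted scikit-learn-style classifier. It is demonstrated with scikit-learn's `LogisticRegression` as a stand-in; that `CASIMAClassifier` exposes `classes_` exactly like this is an assumption to confirm in the documentation. The name `check_proba_output` is hypothetical.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def check_proba_output(clf, X):
    """Check scikit-learn predict_proba conventions on a fitted classifier.

    Assumes `clf` exposes `classes_`, `predict` and `predict_proba` as in scikit-learn.
    Returns the probability array for further use.
    """
    proba = np.asarray(clf.predict_proba(X))
    labels = np.asarray(clf.predict(X))
    # One row per sample, one column per class (ordered as clf.classes_).
    assert proba.shape == (len(X), len(clf.classes_))
    # Probabilities are non-negative and each row sums to one.
    assert np.all(proba >= 0.0)
    assert np.allclose(proba.sum(axis=1), 1.0)
    # Every predicted label is one of the known classes.
    assert np.all(np.isin(labels, clf.classes_))
    return proba

# Stand-in estimator (NOT casimac) on toy data like that of the CASIMAClassifier example.
rng = np.random.RandomState(42)
X = rng.uniform(-10, 10, 10).reshape(-1, 1)
y = (X[:, 0] > 0).astype(int)
clf = LogisticRegression().fit(X, y)
proba = check_proba_output(clf, np.linspace(-10, 10, 5).reshape(-1, 1))
```

With labels 0 and 1, column `proba[:, 1]` holds the class-1 probability, which matches how `p_sample[:,1]` is plotted in the example.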

Below is a short example.

from casimac import CASIMAClassifier

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
import matplotlib.pyplot as plt

# Create toy data
N = 10
seed = 42
X = np.random.RandomState(seed).uniform(-10,10,N).reshape(-1,1)
y = np.zeros(X.size)
y[X[:,0]>0] = 1

# Classify
clf = CASIMAClassifier(GaussianProcessRegressor)
clf.fit(X, y)

# Predict
X_sample = np.linspace(-10,10,100).reshape(-1,1)
y_sample = clf.predict(X_sample)
p_sample = clf.predict_proba(X_sample)

# Plot result
plt.figure(figsize=(8,3))
plt.plot(X_sample,y_sample,label="class prediction")
plt.plot(X_sample,p_sample[:,1],label="class probability prediction")
plt.scatter(X,y,c='r',label="train data")
plt.xlabel("X")
plt.ylabel("label / probability")
plt.legend()
plt.show()
[Figure: class prediction and class probability prediction over X_sample, with the training data; axes: X vs. label / probability]

📖 Citation

If you find this code useful, please consider citing:

@article{10.1371/journal.pone.0279876,
      doi={10.1371/journal.pone.0279876},
      author={Heese, Raoul and Schmid, Jochen and Walczak, Micha{\l} and Bortz, Michael},
      journal={PLOS ONE},
      publisher={Public Library of Science},
      title={Calibrated simplex-mapping classification},
      year={2023},
      month={01},
      volume={18},
      url={https://doi.org/10.1371/journal.pone.0279876},
      pages={1-26},
      number={1}
      }

License

This project is licensed under the MIT License - see the LICENSE file for details.

Download files


Source Distribution

casimac-1.2.4.tar.gz (13.7 kB)

Built Distribution

casimac-1.2.4-py3-none-any.whl (12.3 kB)

File details

Details for the file casimac-1.2.4.tar.gz.

File metadata

  • Download URL: casimac-1.2.4.tar.gz
  • Size: 13.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.0

File hashes

Hashes for casimac-1.2.4.tar.gz
Algorithm Hash digest
SHA256 74bda7e7be5c7731ac3df3386d3d76c3673192b57dccbbc26a6453ece77fcf52
MD5 13333853c02b768862ecac5972bd88d6
BLAKE2b-256 55160e75a13df17acf7b3a6939cd710e978d9258197b9d023fde3bd839d9edce


File details

Details for the file casimac-1.2.4-py3-none-any.whl.

File metadata

  • Download URL: casimac-1.2.4-py3-none-any.whl
  • Size: 12.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.0

File hashes

Hashes for casimac-1.2.4-py3-none-any.whl
Algorithm Hash digest
SHA256 61dcdc312d534bd9718a34dc7bbfb274945ea1980fd2da87e3bb78bfd1a22169
MD5 108bbc0009d2e177c2a3915db90cf231
BLAKE2b-256 00dac83d9c2f6fbf276e95e91a65c51dad3463488933252d983f7839e89bd7a3

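Before installing from a manually downloaded file, its SHA256 digest can be compared with the values listed above. A minimal sketch using Python's standard `hashlib`; the helper names `sha256_of` and `verify_download` are illustrative and not part of casimac or pip.

```python
import hashlib
from pathlib import Path

def sha256_of(path, chunk_size=65536):
    """Return the hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# SHA256 digests as listed above for release 1.2.4.
EXPECTED_SHA256 = {
    "casimac-1.2.4.tar.gz": "74bda7e7be5c7731ac3df3386d3d76c3673192b57dccbbc26a6453ece77fcf52",
    "casimac-1.2.4-py3-none-any.whl": "61dcdc312d534bd9718a34dc7bbfb274945ea1980fd2da87e3bb78bfd1a22169",
}

def verify_download(path):
    """Return True if the file's SHA256 matches the listed digest for its file name."""
    p = Path(path)
    expected = EXPECTED_SHA256.get(p.name)
    if expected is None:
        raise KeyError(f"no listed digest for {p.name}")
    return sha256_of(p) == expected
```

pip can also enforce digests itself in hash-checking mode, for example with a requirements line `casimac==1.2.4 --hash=sha256:<digest>` installed via `pip install --require-hashes -r requirements.txt`.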
