
Generate evaluation metric reports for machine learning models in two lines of code.

Project description

Machine Learning Report Toolkit

A plug-in to generate various evaluation metrics and reports (PR curves, classification reports, confusion matrices) for supervised machine learning models using only two lines of code.

from ml_report_kit import MLReport

report = MLReport(y_true_label, y_pred_label, y_pred_prob, class_names)
report.run(results_path="results")
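A minimal sketch of the expected inputs (the toy data and shapes below are assumptions based on the cross-fold example further down: label lists of strings, plus one probability row per sample with one column per class):

```python
# Illustrative inputs for MLReport (toy data; names and shapes are examples):
# - y_true_label / y_pred_label: per-sample class names as strings
# - y_pred_prob: one row per sample, one probability per class
class_names = ["cat", "dog"]
y_true_label = ["cat", "dog", "dog"]
y_pred_label = ["cat", "dog", "cat"]
y_pred_prob = [
    [0.9, 0.1],  # sample 0: P(cat), P(dog)
    [0.2, 0.8],
    [0.6, 0.4],
]

# Sanity checks: one label and one probability row per sample,
# and each row has one probability per class.
assert len(y_true_label) == len(y_pred_label) == len(y_pred_prob)
assert all(len(row) == len(class_names) for row in y_pred_prob)

# With ml-report-kit installed, the report would then be generated by:
# report = MLReport(y_true_label, y_pred_label, y_pred_prob, class_names)
# report.run(results_path="results")
```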

This will generate a classifier report containing the following information:

  • A classification report with precision, recall and F1.
  • A visualization of the precision and recall curves as a function of the threshold for each class.
  • A confusion matrix.
  • A .csv file with precision and recall at different thresholds.
  • A .csv file with prediction scores for each class for each sample.

All this information is saved in the results folder under different filenames, including images, .csv files, and a .txt file with the classification report.

[Figure: precision × recall vs. threshold]
[Figure: confusion matrix]

                          precision    recall  f1-score   support

             alt.atheism       0.81      0.87      0.84       159
           comp.graphics       0.65      0.81      0.72       194
 comp.os.ms-windows.misc       0.81      0.82      0.81       197
comp.sys.ibm.pc.hardware       0.75      0.75      0.75       196
   comp.sys.mac.hardware       0.86      0.78      0.82       193
          comp.windows.x       0.81      0.81      0.81       198
            misc.forsale       0.74      0.86      0.80       195
               rec.autos       0.92      0.90      0.91       198
         rec.motorcycles       0.95      0.96      0.95       199
      rec.sport.baseball       0.94      0.92      0.93       198
        rec.sport.hockey       0.96      0.97      0.96       200
               sci.crypt       0.95      0.89      0.92       198
         sci.electronics       0.85      0.81      0.83       196
                 sci.med       0.90      0.90      0.90       198
               sci.space       0.94      0.91      0.93       197
  soc.religion.christian       0.90      0.92      0.91       199
      talk.politics.guns       0.86      0.88      0.87       182
   talk.politics.mideast       0.97      0.95      0.96       188
      talk.politics.misc       0.86      0.82      0.84       155
      talk.religion.misc       0.82      0.57      0.67       126

                accuracy                           0.86      3766
               macro avg       0.86      0.86      0.86      3766
            weighted avg       0.86      0.86      0.86      3766
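For reference, the per-class numbers in such a table follow the standard one-vs-rest definitions. A minimal pure-Python sketch (toy labels, not the 20-newsgroups data):

```python
def per_class_metrics(y_true, y_pred, cls):
    """Precision, recall and F1 for one class, one-vs-rest."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p != cls)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

y_true = ["a", "a", "a", "b", "b"]
y_pred = ["a", "a", "b", "b", "b"]
p, r, f1 = per_class_metrics(y_true, y_pred, "a")
# precision = 2/2 = 1.0, recall = 2/3, f1 = 0.8
```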

Example: running ML-Report-Toolkit with cross-validated classification

Install the package and dependencies:

pip install ml-report-kit
pip install scikit-learn

Run the following code:

import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

from ml_report_kit import MLReport

dataset = fetch_20newsgroups(subset='all', shuffle=True, random_state=42)
k_folds = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
folds = {}

for fold_nr, (train_index, test_index) in enumerate(k_folds.split(dataset.data, dataset.target)):
    x_train, x_test = np.array(dataset.data)[train_index], np.array(dataset.data)[test_index]
    y_train, y_test = np.array(dataset.target)[train_index], np.array(dataset.target)[test_index]
    folds[fold_nr] = {"x_train": x_train, "x_test": x_test, "y_train": y_train, "y_test": y_test}

for fold_nr in folds.keys():
    clf = Pipeline([('tfidf', TfidfVectorizer()), ('clf', LogisticRegression(class_weight='balanced'))])
    clf.fit(folds[fold_nr]["x_train"], folds[fold_nr]["y_train"])
    y_pred = clf.predict(folds[fold_nr]["x_test"])
    y_pred_prob = clf.predict_proba(folds[fold_nr]["x_test"])
    y_true_label = [dataset.target_names[sample] for sample in folds[fold_nr]["y_test"]]
    y_pred_label = [dataset.target_names[sample] for sample in y_pred]
    
    report = MLReport(y_true_label, y_pred_label, y_pred_prob, dataset.target_names)
    report.run(results_path="results", fold_nr=fold_nr)

This will generate, for each fold, the reports and metrics mentioned above in the results folder. For each fold there will be the following files:

  • classification_report.txt
  • confusion_matrix.png
  • confusion_matrix.txt
  • predictions_scores.csv
  • For each class:
    • precision_recall_threshold_<class_name>.csv
    • precision_recall_threshold_<class_name>.png
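One common use of the per-class threshold files is picking an operating threshold for deployment. A sketch, assuming the .csv contains threshold, precision and recall columns (the column names are an assumption; check the generated files):

```python
import csv
import io

# In practice you would open the generated file, e.g.
# open("results/precision_recall_threshold_<class_name>.csv");
# here an in-memory sample stands in for it.
sample_csv = io.StringIO(
    "threshold,precision,recall\n"
    "0.1,0.50,0.95\n"
    "0.5,0.80,0.80\n"
    "0.9,0.95,0.40\n"
)

# Pick the threshold that maximizes F1.
best = None
for row in csv.DictReader(sample_csv):
    p, r = float(row["precision"]), float(row["recall"])
    f1 = 2 * p * r / (p + r) if p + r else 0.0
    if best is None or f1 > best[1]:
        best = (float(row["threshold"]), f1)

print(best)  # (threshold, F1) with the highest F1
```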

License

Apache License 2.0

Project details


Download files

Download the file for your platform.

Source Distribution

ml_report_kit-0.1.4.tar.gz (10.2 kB)

Uploaded Source

Built Distribution

ml_report_kit-0.1.4-py3-none-any.whl (9.8 kB)

Uploaded Python 3

File details

Details for the file ml_report_kit-0.1.4.tar.gz.

File metadata

  • Download URL: ml_report_kit-0.1.4.tar.gz
  • Size: 10.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.6

File hashes

Hashes for ml_report_kit-0.1.4.tar.gz
Algorithm Hash digest
SHA256 a44bfc98a2e9c384ea282e8a47581bb667ced1c7e716a7dde012e76368fbd8ea
MD5 65139a738a22ee552f632c76472d5619
BLAKE2b-256 a49b6aad60ee81d85e2b7bcfb8eb6b7436bb502697f42a856c9e4fd613083f2e

See more details on using hashes here.
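To check a downloaded file against the hashes above, a minimal sketch using Python's standard hashlib (the filename below is a placeholder for wherever you saved the download):

```python
import hashlib

def sha256_of(path: str) -> str:
    """Hex SHA-256 digest of a file, read in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            h.update(chunk)
    return h.hexdigest()

# expected = "a44bfc98a2e9c384ea282e8a47581bb667ced1c7e716a7dde012e76368fbd8ea"
# assert sha256_of("ml_report_kit-0.1.4.tar.gz") == expected
```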

File details

Details for the file ml_report_kit-0.1.4-py3-none-any.whl.

File metadata

File hashes

Hashes for ml_report_kit-0.1.4-py3-none-any.whl
Algorithm Hash digest
SHA256 0d6aeff6ed222865f2e8096db537e755bb56ef46e0a48ccb39915a20ca515096
MD5 97617990e06cf21ebf9653f9737627f3
BLAKE2b-256 e8769f85a11837ff90bd2f666a6d2082eaca6c34f0e23c46135584fd34d637e7

