Better insights into machine learning model performance

Project description

Welcome to modelsight

Better insights into machine learning model performance.

Modelsight is a collection of functions that create publication-ready graphics for the visual evaluation of binary classifiers adhering to the scikit-learn interface.

Modelsight is strongly oriented towards the evaluation of already-fitted ExplainableBoostingClassifier models from the interpret package (InterpretML).
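
As a point of reference, the kind of estimator modelsight expects is simply a fitted scikit-learn-compatible classifier. A minimal sketch with an EBM on a toy dataset (the dataset choice is purely illustrative):

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from interpret.glassbox import ExplainableBoostingClassifier

# fit an EBM on a toy binary-classification problem; modelsight's plotting
# functions operate on estimators fitted like this one
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X_train, X_val, y_train, y_val = train_test_split(X, y, stratify=y, random_state=0)

ebm = ExplainableBoostingClassifier()
ebm.fit(X_train, y_train)
val_probas = ebm.predict_proba(X_val)[:, 1]  # probability of the positive class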

Overview

💫 Features

Our goal is to streamline the visual assessment of binary classifiers by creating a set of functions designed to generate publication-ready plots.

Module        Status     Links
Calibration   maturing   Tutorial · API Reference
Curves        maturing   Tutorial · API Reference

:eyeglasses: Install modelsight

  • Operating system: macOS X · Linux
  • Python version: 3.10 (64-bit only)
  • Package managers: pip

pip

Using pip, modelsight releases are available as source packages and binary wheels; all available wheels are listed on PyPI.

$ pip install modelsight

:zap: Quickstart

from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_breast_cancer
from sklearn.metrics import roc_curve
from sklearn.model_selection import (
    StratifiedKFold,
    RepeatedStratifiedKFold,
    train_test_split,
)
from sklearn.preprocessing import StandardScaler
from interpret.glassbox import ExplainableBoostingClassifier

# project-specific helpers (feature selection, model construction, calibration);
# these are not part of modelsight
from utils import (
    select_features,
    get_feature_selector,
    get_model,
    get_calibrated_model,
)

from modelsight.curves import average_roc_curves
from modelsight.config import settings
from modelsight._typing import CVModellingOutput, Estimator

X, y = load_breast_cancer(return_X_y=True, as_frame=True)

# cross-validation configuration, used by the splitters and the loop below
cv_config = {
    "N_REPEATS": 10,
    "N_SPLITS": 10,
    "SHUFFLE": True,
    "SCALE": False,
    "CALIBRATE": True,
    "CALIB_FRACTION": 0.15,
}

outer_cv = RepeatedStratifiedKFold(n_repeats=cv_config.get("N_REPEATS"),
                                   n_splits=cv_config.get("N_SPLITS"),
                                   random_state=settings.misc.seed)
inner_cv = StratifiedKFold(n_splits=cv_config.get("N_SPLITS"),
                           shuffle=cv_config.get("SHUFFLE"),
                           random_state=settings.misc.seed)

# per-fold containers: ground truths, predicted probabilities, fitted models,
# misclassified/correct validation indices, and selected feature subsets
gts_train = []
gts_val = []
probas_train = []
probas_val = []

models = []
errors = []
correct = []
features = []

ts = datetime.timestamp(datetime.now())  # timestamp identifying this run

for i, (train_idx, val_idx) in enumerate(outer_cv.split(X, y)):
    Xtemp, ytemp = X.iloc[train_idx, :], y.iloc[train_idx]
    Xval, yval = X.iloc[val_idx, :], y.iloc[val_idx]

    if cv_config.get("CALIBRATE"):
        Xtrain, Xcal, ytrain, ycal = train_test_split(Xtemp, ytemp, 
                                test_size=cv_config.get("CALIB_FRACTION"), 
                                stratify=ytemp, 
                                random_state=settings.misc.seed)
    else:
        Xtrain, ytrain = Xtemp, ytemp
    
    model = get_model(seed=settings.misc.seed)

    # select features
    feat_subset = select_features(Xtrain, ytrain, 
                                selector=get_feature_selector(settings.misc.seed), 
                                cv=inner_cv,
                                scale=False,
                                frac=1.25)
    features.append(feat_subset)

    if cv_config.get("SCALE"):
        numeric_cols = Xtrain.select_dtypes(include=[np.float64, np.int64]).columns.tolist()
        scaler = StandardScaler()
        Xtrain.loc[:, numeric_cols] = scaler.fit_transform(Xtrain.loc[:, numeric_cols])
        Xtest.loc[:, numeric_cols] = scaler.transform(Xtest.loc[:, numeric_cols])            

    model.fit(Xtrain[feat_subset], ytrain)

    if cv_config.get("CALIBRATE"):
        model = get_calibrated_model(model, 
                                    X=Xcal.loc[:, feat_subset],
                                    y=ycal)
    
    models.append(model)

    # accumulate ground-truths
    gts_train.append(ytrain)
    gts_val.append(yval)

    # accumulate predictions
    train_pred_probas = model.predict_proba(Xtrain.loc[:, feat_subset])[:, 1]
    val_pred_probas = model.predict_proba(Xval.loc[:, feat_subset])[:, 1]

    probas_train.append(train_pred_probas)
    probas_val.append(val_pred_probas)      
    
    # identify correct and erroneous predictions according to the
    # classification cut-off that maximizes Youden's J statistic (J = TPR - FPR)
    fpr, tpr, thresholds = roc_curve(ytrain, train_pred_probas)
    idx = np.argmax(tpr - fpr)
    youden = thresholds[idx]

    labels_val = np.where(val_pred_probas >= youden, 1, 0)

    # indexes of validation instances misclassified by the model
    error_idxs = Xval[(yval != labels_val)].index
    errors.append(error_idxs)
    
    # indexes of correct predictions
    correct.append(Xval[(yval == labels_val)].index)
    
    
cv_results = CVModellingOutput(
    # fold sizes can differ slightly, so per-fold arrays are stored with dtype=object
    gts_train=np.array(gts_train, dtype=object),
    gts_val=np.array(gts_val, dtype=object),
    probas_train=np.array(probas_train, dtype=object),
    probas_val=np.array(probas_val, dtype=object),
    gts_train_conc=np.concatenate(gts_train),
    gts_val_conc=np.concatenate(gts_val),
    probas_train_conc=np.concatenate(probas_train),
    probas_val_conc=np.concatenate(probas_val),
    models=models,
    errors=np.array(errors, dtype=object),
    correct=np.array(correct, dtype=object),
    features=np.array(features, dtype=object)
)

# Plot the average receiver-operating characteristic (ROC) curve
model_names = ["EBM"]
alpha = 0.05
alph_str = str(alpha).split(".")[1]
alpha_formatted = f".{alph_str}"
roc_symbol = "*"
n_boot = 100
kwargs = dict()
palette = ["#1f77b4"]  # one color per entry in model_names

f, ax = plt.subplots(1, 1, figsize=(8, 8))

f, ax, barplot, bars, aucs_cis = average_roc_curves(cv_results,
                                                    colors=palette,
                                                    model_keys=model_names,
                                                    show_ci=True,
                                                    n_boot=n_boot,
                                                    bars_pos=[
                                                        0.5, 0.01, 0.4, 0.1*len(model_names)],
                                                    random_state=settings.misc.seed,
                                                    ax=ax,
                                                    **kwargs)
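
average_roc_curves returns the figure and axes it drew on, together with the bar-plot artists and the bootstrapped AUC confidence intervals, so the result can be customized and saved like any other Matplotlib figure. For example (the file name is illustrative):

ax.set_title("Average ROC curve - EBM (repeated stratified CV)")
f.savefig("average_roc.png", dpi=300, bbox_inches="tight")
plt.show()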

:gift_heart: Contributing

Interested in contributing? Check out the contributing guidelines. Please note that this project is released with a Code of Conduct. By contributing to this project, you agree to abide by its terms.

🛣️ Roadmap

Features:

  • Calibration module to assess the calibration of predicted probabilities via a Hosmer-Lemeshow plot (see the sketch after this list)
  • Average Receiver-operating characteristic curves
  • Average Precision-recall curves
  • Feature importance (Global explanation)
  • Individualized explanations (Local explanation)
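
For intuition about the Hosmer-Lemeshow-style calibration plot mentioned above: predicted probabilities are grouped into bins (typically deciles) and the observed event rate in each bin is compared with the mean predicted probability. The sketch below is a generic illustration of that idea, not modelsight's API:

import numpy as np
import pandas as pd

def hosmer_lemeshow_table(y_true, y_prob, n_bins=10):
    """Group predictions into probability bins and compare observed vs. expected event rates."""
    df = pd.DataFrame({"y": np.asarray(y_true), "p": np.asarray(y_prob)})
    # qcut places (roughly) equal numbers of samples in each probability bin
    df["bin"] = pd.qcut(df["p"], q=n_bins, duplicates="drop")
    grouped = df.groupby("bin", observed=True)
    return pd.DataFrame({
        "n": grouped.size(),
        "observed_rate": grouped["y"].mean(),   # fraction of positives per bin
        "expected_rate": grouped["p"].mean(),   # mean predicted probability per bin
    })

Plotting observed_rate against expected_rate per bin (with the diagonal as reference) gives the kind of calibration plot the roadmap refers to; a well-calibrated model keeps the two close in every bin.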

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

modelsight-0.3.2.tar.gz (27.1 kB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

modelsight-0.3.2-py3-none-any.whl (27.3 kB)

Uploaded Python 3

File details

Details for the file modelsight-0.3.2.tar.gz.

File metadata

  • Download URL: modelsight-0.3.2.tar.gz
  • Upload date:
  • Size: 27.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for modelsight-0.3.2.tar.gz
Algorithm Hash digest
SHA256 2b02bb84381bb9df6887c1b81d40e6e889d83eb214d5f764e0a16190fe2c7d80
MD5 f0459b6d0fa94bff003a801850357a3f
BLAKE2b-256 013f753bd637c4c661b6d68286266347ce2a4f38ff4c433d7dd386cbfd2d8ec5

See more details on using hashes here.
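
For example, a downloaded archive can be checked against the SHA256 digest listed above with Python's standard hashlib module (a generic sketch, not specific to modelsight):

import hashlib

# expected SHA256 digest for modelsight-0.3.2.tar.gz, taken from the table above
EXPECTED = "2b02bb84381bb9df6887c1b81d40e6e889d83eb214d5f764e0a16190fe2c7d80"

def sha256_of(path, chunk_size=1 << 20):
    """Compute the SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

assert sha256_of("modelsight-0.3.2.tar.gz") == EXPECTED, "hash mismatch"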

File details

Details for the file modelsight-0.3.2-py3-none-any.whl.

File metadata

  • Download URL: modelsight-0.3.2-py3-none-any.whl
  • Upload date:
  • Size: 27.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for modelsight-0.3.2-py3-none-any.whl
Algorithm Hash digest
SHA256 19a0974bbf0605a1914abda7d1357ae04fc5556e6982136842bac0ce1a42bf3a
MD5 b333c7fc177b7367d759e9ba01bf5936
BLAKE2b-256 b1da077a40da98f6ae76c02ba1d6788b79154e650a3e3b4619ec810d57cfc96c

See more details on using hashes here.
