Skip to main content

skflex provides a suite of flexible utility functions for use with the sklearn library

Project description

skflex

GitHub PyPI Build Status Tweet

FLEXIBLE FUNCTIONS ----- FAST PROCESSING AND EVALUATION

skflex provides a suite of utility functions for use with the sklearn library. The module primarily focuses on producing typical plots and metrics for evaluating machine learning models. It has been designed with flexability and customisation in mind to speed up workflows, and enhance comparative evaluation. Refer to the Github repository for example outputs and images.

Installation and Import

pip install skflex

import skflex.skflex as skf

ROC_AUC Curve

roc_auc_plot(models, X_test = None, y_test = None, width = 14, height = 12, legend_size = 14, title='ROC Curve')

Accepts fitted model(s) and test data. It will then:

  • Calculate ROC
  • Calculate AUC
  • Plot ROC curve with AUC provided in the legend

Parameters:

  • models - fitted model objects. NOTE: Only models with a 'predict_proba' or 'decision_function' method are supported.
  • X_test - test feature set
  • y_test - test labels set
  • title - title for ROC curve plot
  • width - plot width
  • height - plot height
  • legend_size - size of plot legend

Example:

from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression

model_1 = GaussianNB()
model_2 = LogisticRegression()

model_1.fit(X_train, y_train)
model_2.fit(X_train, y_train)

skf.roc_auc_plot(model_1, model_2, X_test = X_test, y_test = y_test, 
                title = 'Example ROC plot')

Classification Reports

classifier_train_report(models, X_train = None, y_train = None, X_test = None, y_test = None, scoring = 'accuracy', title = 'Reports')

Accepts classifier models, training data, and test data. It will then:

  • Fit the model(s) to training data
  • Make predictions using test data
  • Produce classification report for comparison
  • Produce confusion matrix for comparison
  • Provide an ordered summary (ranked best to worst score) using given evaluation metric

Parameters:

  • models - model objects to be trained and evaluated
  • X_train - training feature set
  • y_train - training label set
  • X_test - test feature set
  • y_test - test label set
  • scoring - summary evaluation metric. Common classifier evaluation metrics including accuracy, f1, precision, and recall are supported. Refer to sklean scoring documentation for more information. Scoring methodologies should be passed as strings, for example, precision would be passed as scoring = 'precision'
  • title - title for output report

Example:

from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression

model_1 = GaussianNB()
model_2 = LogisticRegression()

skf.classifier_train_report(model_1, model_2, X_train = X_train, y_train = y_train, 
                            X_test = X_test, 
                            y_test = y_test, 
                            scoring = 'accuracy', 
                            title = 'Example Report')

Validation Curve

validation_plot(model = None, param = None, param_grid = None, X_train = None, y_train = None, cv = 5, scoring = 'accuracy', width = 9, height = 9, title = 'Validation Curve')

Accepts a model, a related hyper-parameter, a list of hyper-parameter values, training and test data, number of cross-validation folds, scoring methodology, as well as a plot title. It will produce a plot of the validation curve for the training and test data using the mean scores and standard deviations obtained through the cross-validation process.

Parameters:

  • model - model object
  • param - hyperparameter to be used to produce the validation curve
  • param_grid - hyperparameter values to be tested
  • X_train - training feature set
  • y_train - training label set
  • cv - number of cross-validation folds
  • scoring - scoring methodology used during cross-validation process
  • title - title for validation plot
  • width - plot width
  • height - plot height

Example:

from sklearn.tree import DecisionTreeClassifier

model_1 = DecisionTreeClassifier()
params = [5, 10, 15, 20, 30, 40, 50]

skf.validation_plot(model = model_1, param = 'max_depth', param_grid = params, 
                    X_train = X_train, 
                    y_train = y_train, 
                    title = 'Example Validation Curve')

Train, Validation and Test Split

train_val_test(data = None, class_labels = None, train = 0.6, val = 0.2, shuffle = True, random_state = None)

Accepts a Pandas dataframe and will return a training, validation, and test set. Operates in a similar fashion to the sklearn train_test_split function by defining a percentage split for the training and validation sets (example 0.6 = 60%). The remainder is allocated to the test set.

Parameters:

  • data - dataframe to be split into a training, validation, and test set
  • class_labels - column in the dataframe containing class labels
  • train - percentage of data to be allocated to the training set
  • val - percentage of data to be allocated to the validation set
  • shuffle - if true, will shuffle the rows in the dataframe before splitting
  • random_state - if shuffle is ture, will set a random seed so that ordering produced by shuffle can be reproduced

Returns: X_train, y_train, X_val, y_val, X_test, y_test

Example:

X_train, y_train, X_val, y_val, X_test, y_test = skf.train_val_test(data = my_data, 
                                                                    class_labels = 'labels', 
                                                                    train = 0.6, 
                                                                    val = 0.2)

PCA Scree Plot

pca_scree_plot(data = None, n_components = None, width = 16, height = 10, legend_size = 12, scale_data = False, title = 'PCA Scree Plot')

Accepts data (array/dataframe), and number of principal components to be analysed. It will produce a scree plot of the cumulative variance explained.

Parameters:

  • data - dataset to be analysed
  • n_components - number of principal components to be analysed
  • width - width of plot
  • height - height of plot
  • legend_size - size of plot legend
  • scale_data - normalises data before analysis and plotting. If the data being passed has not yet been normalised, this parameter should be set as scale_data = True
  • title - plot title

Example:

from sklearn.preprocessing import scale

n_data = scale(my_data)

skf.pca_scree_plot(data = n_data, n_components = 70, title = 'Example PCA Scree Plot')

Requirements

  • Sklearn >= 0.24.1
  • Matplotlib >= 3.3.4
  • Pandas >= 1.2.4
  • Numpy >= 1.20.1

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

skflex-1.0.2.tar.gz (5.7 kB view details)

Uploaded Source

Built Distribution

skflex-1.0.2-py3-none-any.whl (6.1 kB view details)

Uploaded Python 3

File details

Details for the file skflex-1.0.2.tar.gz.

File metadata

  • Download URL: skflex-1.0.2.tar.gz
  • Upload date:
  • Size: 5.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/3.10.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.8.8

File hashes

Hashes for skflex-1.0.2.tar.gz
Algorithm Hash digest
SHA256 44c6cb69b170439477bbde6fa7308a89a8f9ec23ce8fb2513b2029abcdfa817f
MD5 c55d68e68cab8c1b788de768e92cb76e
BLAKE2b-256 15f5984c0980db52ffa8e819f77ca1b1b71cbe5b1477f781a8c6f8d935a61ec3

See more details on using hashes here.

File details

Details for the file skflex-1.0.2-py3-none-any.whl.

File metadata

  • Download URL: skflex-1.0.2-py3-none-any.whl
  • Upload date:
  • Size: 6.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/3.10.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.8.8

File hashes

Hashes for skflex-1.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 bc9e886844b03693ce8c4725038c0bd45b9b37de8b2ce5dcd5f3d427b7b5756e
MD5 fb69c07b80d1b80ba933196391c5d6ff
BLAKE2b-256 e34e7d04a6372774a9206860c300e18d0f799f5dea9b54943eced63aa7a07669

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page