
A high-level grid search / cross validation library for scikit-learn

Easy Grid Search / Cross Validation

From data to score in 4 lines of code.

This library lets you quickly train machine learning classifiers by automatically splitting the dataset and applying both grid search and cross validation during training. Users can either define the parameters themselves or let the GSCV object choose them automatically (based on the classifier).

This library is an extension of the scikit-learn project.
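
For orientation, the workflow described above corresponds roughly to the following plain scikit-learn steps. This is a sketch of the general technique, not easy-gscv's actual implementation; the 80/20 split, the fixed random_state, and the parameter grid are assumptions chosen for illustration.

```python
from sklearn import datasets
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = datasets.load_iris(return_X_y=True)

# Hold out a test set (assumed 80/20 split; easy-gscv's ratio may differ).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y
)

# Grid search with 5-fold cross validation on the training data only.
search = GridSearchCV(
    KNeighborsClassifier(),
    param_grid={"n_neighbors": [3, 5, 8], "weights": ["uniform", "distance"]},
    cv=5,
)
search.fit(X_train, y_train)

# Score the refitted best estimator against the held-out test data.
test_score = search.score(X_test, y_test)
print(search.best_params_, round(test_score, 3))
```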

Example:

from sklearn.neural_network import MLPClassifier
from sklearn import datasets
from easy_gscv.classifiers import GSCV

# Create test dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target
clf = MLPClassifier()

# Create model instance
gscv_model = GSCV(clf, X, y)

# Get score
gscv_model.score()

install

Requires Python 3.7+

pip install easy-gscv

create

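from sklearn.linear_model import LogisticRegression
from easy_gscv.models import GSCV

# X and y are the feature matrix and labels, as in the example above
clf = LogisticRegression()
gscv_model = GSCV(
    clf, X, y, cv=15, n_jobs=-1, params={
        'C': [10, 100],
        'penalty': ['l2']
    }
)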

There is no need to create separate train / test datasets; the model does this automatically on initialization. If no parameters are provided, the grid search is performed on a default set for the chosen classifier. Pass the params keyword to override these defaults.
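
The default-or-override behaviour can be pictured as a simple lookup. The sketch below is hypothetical: DEFAULT_PARAMS, its values, and resolve_params are invented for illustration and are not taken from easy-gscv's source.

```python
from typing import Any, Dict, List, Optional

ParamGrid = Dict[str, List[Any]]

# Hypothetical defaults, not the values shipped with easy-gscv.
DEFAULT_PARAMS: Dict[str, ParamGrid] = {
    "KNeighborsClassifier": {
        "n_neighbors": [3, 5, 8],
        "weights": ["uniform", "distance"],
    },
    "LogisticRegression": {"C": [1, 10, 100], "penalty": ["l2"]},
}


def resolve_params(
    classifier_name: str, params: Optional[ParamGrid] = None
) -> ParamGrid:
    """Return the user's grid if one is given, otherwise the default grid."""
    if params is not None:
        return params
    try:
        return DEFAULT_PARAMS[classifier_name]
    except KeyError:
        raise ValueError(
            f"No default parameters for {classifier_name!r}"
        ) from None


# No params given: falls back to the default grid.
print(resolve_params("LogisticRegression"))
# Params given: the user's grid wins.
print(resolve_params("LogisticRegression", {"C": [10, 100], "penalty": ["l2"]}))
```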

The number of folds used for cross validation can be specified with the cv keyword. To speed up training, use the n_jobs parameter to set the number of CPU cores to use (or set it to -1 to use all available cores).
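
The -1 value appears to follow the joblib convention that scikit-learn uses for parallelism, where negative values count back from the number of available cores. A minimal sketch of that mapping, assuming n_jobs is forwarded unchanged; resolve_n_jobs is a hypothetical helper written for illustration and is not part of easy-gscv or scikit-learn.

```python
import os
from typing import Optional


def resolve_n_jobs(n_jobs: int, n_cpus: Optional[int] = None) -> int:
    """Map an n_jobs value to a worker count, joblib style."""
    if n_cpus is None:
        n_cpus = os.cpu_count() or 1
    if n_jobs == 0:
        raise ValueError("n_jobs == 0 has no meaning")
    if n_jobs < 0:
        # -1 means all cores, -2 means all but one, and so on (at least 1).
        return max(n_cpus + 1 + n_jobs, 1)
    return n_jobs


print(resolve_n_jobs(-1))  # number of cores on this machine
```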

The model accepts either sklearn classifiers or string values. You can get a list of valid classifiers from the 'classifiers' property. Passing a string to the GSCV object saves you from having to import the sklearn classifier yourself.

gscv_model = GSCV('RandomForestClassifier', X, y)
gscv_model.classifiers

'KNeighborsClassifier',
'RandomForestClassifier',
'GradientBoostingClassifier',
'MLPClassifier',
'LogisticRegression',
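
One way to accept both strings and estimator objects is a small name-to-module lookup that imports the class only when it is requested. The sketch below is an assumption about how such a lookup could work, not easy-gscv's code; SUPPORTED, list_classifiers, and resolve_classifier are hypothetical names, while the module paths are the real scikit-learn locations of the five classifiers listed above.

```python
import importlib
from typing import Any, Dict, List

# scikit-learn module that defines each supported classifier.
SUPPORTED: Dict[str, str] = {
    "KNeighborsClassifier": "sklearn.neighbors",
    "RandomForestClassifier": "sklearn.ensemble",
    "GradientBoostingClassifier": "sklearn.ensemble",
    "MLPClassifier": "sklearn.neural_network",
    "LogisticRegression": "sklearn.linear_model",
}


def list_classifiers() -> List[str]:
    """Return the names that can be passed as strings."""
    return list(SUPPORTED)


def resolve_classifier(clf: Any) -> Any:
    """Turn a classifier name into an instance; pass other objects through."""
    if not isinstance(clf, str):
        return clf
    if clf not in SUPPORTED:
        raise ValueError(
            f"Unsupported classifier {clf!r}; choose from {list_classifiers()}"
        )
    # Import lazily so only the requested module is loaded.
    module = importlib.import_module(SUPPORTED[clf])
    return getattr(module, clf)()
```

With scikit-learn installed, resolve_classifier("RandomForestClassifier") returns a RandomForestClassifier with default settings, while an estimator object passed in is returned unchanged.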

score

gscv_model.score()

The grid search is performed on the training data. Use the score method to evaluate how well the model generalizes by scoring it against the held-out test dataset.

get_best_estimator

gscv_model.get_best_estimator()

Returns the best scoring sklearn classifier (based on training data). As it is a valid scikit-learn classifier, you can use it to do anything that you could do with any other sklearn classifier.
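
As an illustration of what a fitted scikit-learn classifier allows, the sketch below predicts labels, reads class probabilities, and round-trips the estimator through pickle. It uses scikit-learn's own GridSearchCV.best_estimator_ as a stand-in for the object returned by get_best_estimator(), on the assumption that both are ordinary fitted estimators; the small parameter grid is arbitrary.

```python
import pickle

from sklearn import datasets
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

X, y = datasets.load_iris(return_X_y=True)

search = GridSearchCV(KNeighborsClassifier(), {"n_neighbors": [3, 5]}, cv=5)
search.fit(X, y)

# Stand-in for gscv_model.get_best_estimator()
best = search.best_estimator_

# Use it like any other fitted classifier.
predictions = best.predict(X[:5])
probabilities = best.predict_proba(X[:5])

# Persist and restore it, for example to reuse the model later.
restored = pickle.loads(pickle.dumps(best))
print(predictions, restored.get_params()["n_neighbors"])
```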

The following classifiers are currently supported, with the eventual goal of supporting all scikit-learn classifiers.

  • KNeighborsClassifier
  • RandomForestClassifier
  • GradientBoostingClassifier
  • MLPClassifier
  • LogisticRegression

get_fit_details

As cross validation returns an average, it can be helpful to get a more detailed overview of the best scoring classifier.

This method returns a table like the one displayed below, which can then be used to further refine the choice of parameters for subsequent runs.

from sklearn.neighbors import KNeighborsClassifier

clf = KNeighborsClassifier()
gscv_model = GSCV(clf, X, y)
gscv_model.get_fit_details()

0.965 (+/-0.026) for {'weights': 'uniform', 'n_neighbors': 3}
0.977 (+/-0.013) for {'weights': 'distance', 'n_neighbors': 3}
0.979 (+/-0.011) for {'weights': 'uniform', 'n_neighbors': 5}
0.979 (+/-0.011) for {'weights': 'distance', 'n_neighbors': 5}
0.976 (+/-0.018) for {'weights': 'uniform', 'n_neighbors': 8}
0.975 (+/-0.018) for {'weights': 'distance', 'n_neighbors': 8}
0.971 (+/-0.022) for {'weights': 'uniform', 'n_neighbors': 12}
0.973 (+/-0.024) for {'weights': 'distance', 'n_neighbors': 12}
0.973 (+/-0.025) for {'weights': 'uniform', 'n_neighbors': 15}
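
A table in this shape can be produced from the mean_test_score, std_test_score, and params entries of a scikit-learn GridSearchCV.cv_results_ dictionary. The sketch below assumes the +/- figure is two standard deviations, as in scikit-learn's own grid search example; whether easy-gscv computes it the same way is not confirmed here. format_fit_details is a hypothetical helper and the input values are hand-written, not real results.

```python
from typing import Any, List, Mapping, Sequence


def format_fit_details(cv_results: Mapping[str, Sequence[Any]]) -> List[str]:
    """Format one 'mean (+/-spread) for params' line per parameter combination."""
    lines = []
    for mean, std, params in zip(
        cv_results["mean_test_score"],
        cv_results["std_test_score"],
        cv_results["params"],
    ):
        # Assumed spread: two standard deviations across the folds.
        lines.append("%0.3f (+/-%0.03f) for %r" % (mean, std * 2, params))
    return lines


# Hand-written input in the layout of GridSearchCV.cv_results_
example = {
    "mean_test_score": [0.965, 0.977],
    "std_test_score": [0.013, 0.0065],
    "params": [
        {"weights": "uniform", "n_neighbors": 3},
        {"weights": "distance", "n_neighbors": 3},
    ],
}

for line in format_fit_details(example):
    print(line)
```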
