
Hyperparameter tuning for scikit-learn models using genetic algorithms

Project description


Sklearn-genetic-opt

Hyperparameter tuning for scikit-learn models using evolutionary algorithms.

This is meant to be an alternative to popular methods in scikit-learn such as grid search (GridSearchCV) and randomized search (RandomizedSearchCV).
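To see why an exhaustive grid can get expensive, the sketch below counts the candidates a grid search would need for the discrete part of the search space used in the example further down. It assumes the integer ranges are inclusive on both ends (as with Python's random.randint); the continuous parameter is left out because a grid would first have to discretize it.

```python
import itertools

# Discrete part of the example's search space.
# Assumption: integer ranges are inclusive on both ends.
search_space = {
    "criterion": ["gini", "entropy"],
    "max_depth": range(2, 21),        # 2..20 -> 19 values
    "max_leaf_nodes": range(2, 31),   # 2..30 -> 29 values
}

# Every combination an exhaustive grid search would have to evaluate.
grid = list(itertools.product(*search_space.values()))
n_candidates = len(grid)

cv_folds = 3                          # cv=3, as in the example
n_fits = n_candidates * cv_folds

print(n_candidates)  # 2 * 19 * 29 = 1102
print(n_fits)        # 1102 * 3 = 3306
```

Discretizing min_weight_fraction_leaf into, say, 11 evenly spaced points would multiply this to 12,122 candidates. A genetic search instead evaluates roughly population_size individuals per generation (16 x 30 = 480 in the example's configuration; the exact count depends on the evolutionary algorithm used).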

Sklearn-genetic-opt uses evolutionary algorithms from the deap package to find the "best" set of hyperparameters, i.e. the one that optimizes (maximizes or minimizes) the cross-validation score. It can be used for both regression and classification problems.
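The loop below is a minimal, stdlib-only sketch of the kind of genetic algorithm described above, not the library's implementation (which relies on deap). Its arguments mirror the GASearchCV options used in the example (population_size, generations, tournament_size, crossover_probability, mutation_probability, elitism, criteria), but toy_score is a hypothetical stand-in for a cross-validation score, and the operators (tournament selection, uniform crossover, resampling mutation) are illustrative choices.

```python
import random

# Hypothetical search space with inclusive integer bounds.
SPACE = {"max_depth": (2, 20), "max_leaf_nodes": (2, 30)}

def toy_score(ind):
    # Hypothetical stand-in for a mean CV score; peaks at (8, 20).
    return -((ind["max_depth"] - 8) ** 2) - ((ind["max_leaf_nodes"] - 20) ** 2) / 4

def random_individual(rng):
    return {k: rng.randint(lo, hi) for k, (lo, hi) in SPACE.items()}

def tournament(pop, fitness, k, rng):
    # Pick k individuals at random and return the fittest of them.
    contenders = rng.sample(range(len(pop)), k)
    return pop[max(contenders, key=lambda i: fitness[i])]

def crossover(a, b, rng):
    # Uniform crossover: each gene is taken from either parent with p=0.5.
    return {k: (a if rng.random() < 0.5 else b)[k] for k in a}

def mutate(ind, p, rng):
    # Each gene is resampled within its bounds with probability p.
    return {k: (rng.randint(*SPACE[k]) if rng.random() < p else v)
            for k, v in ind.items()}

def evolve(score, criteria="max", population_size=16, generations=30,
           tournament_size=3, crossover_probability=0.9,
           mutation_probability=0.05, elitism=True, seed=0):
    rng = random.Random(seed)
    sign = 1 if criteria == "max" else -1  # minimizing == maximizing -score
    pop = [random_individual(rng) for _ in range(population_size)]
    for _ in range(generations):
        fitness = [sign * score(ind) for ind in pop]
        best = pop[max(range(len(pop)), key=lambda i: fitness[i])]
        # Elitism copies the current best unchanged into the next generation.
        children = [best] if elitism else []
        while len(children) < population_size:
            p1 = tournament(pop, fitness, tournament_size, rng)
            p2 = tournament(pop, fitness, tournament_size, rng)
            if rng.random() < crossover_probability:
                child = crossover(p1, p2, rng)
            else:
                child = dict(p1)
            children.append(mutate(child, mutation_probability, rng))
        pop = children
    fitness = [sign * score(ind) for ind in pop]
    return pop[max(range(len(pop)), key=lambda i: fitness[i])]

best = evolve(toy_score)
print(best, toy_score(best))
```

With elitism enabled, the best fitness can never decrease from one generation to the next, which is why it is usually paired with a small mutation probability.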

Usage:

Install sklearn-genetic-opt

It's advised to install sklearn-genetic-opt in a virtual environment. Inside the environment, run:

pip install sklearn-genetic-opt

Example

from sklearn_genetic import GASearchCV
from sklearn_genetic.utils import plot_fitness_evolution
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt


data = load_digits() 
y = data['target']
X = data['data'] 

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

clf = DecisionTreeClassifier()

evolved_estimator = GASearchCV(clf,
                               cv=3,
                               scoring='accuracy',
                               population_size=16,
                               generations=30,
                               tournament_size=3,
                               elitism=True,
                               crossover_probability=0.9,
                               mutation_probability=0.05,
                               continuous_parameters={'min_weight_fraction_leaf': (0, 0.5)},
                               categorical_parameters={'criterion': ['gini', 'entropy']},
                               integer_parameters={'max_depth': (2, 20), 'max_leaf_nodes': (2, 30)},
                               criteria='max',
                               n_jobs=-1,
                               verbose=True)

evolved_estimator.fit(X_train, y_train)
# Best parameters found
print(evolved_estimator.best_params)
# Use the model fitted with the best parameters
y_predict_ga = evolved_estimator.predict(X_test)
print(accuracy_score(y_test, y_predict_ga))

# See the evolution of the optimization per generation
plot_fitness_evolution(evolved_estimator)
plt.show()

# Saved metadata for further analysis
print(evolved_estimator.history)
print(evolved_estimator.logbook)
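The three dictionaries passed to GASearchCV describe different kinds of search dimensions: continuous ranges, categorical choices and integer ranges. The sketch below shows one way a random candidate could be drawn from them; sample_candidate is a hypothetical helper, and uniform sampling within inclusive bounds is an assumption, not necessarily how the library initializes its population.

```python
import random

# Search space copied from the GASearchCV call above.
continuous_parameters = {"min_weight_fraction_leaf": (0, 0.5)}
categorical_parameters = {"criterion": ["gini", "entropy"]}
integer_parameters = {"max_depth": (2, 20), "max_leaf_nodes": (2, 30)}

def sample_candidate(rng):
    """Hypothetical helper: draw one hyperparameter set uniformly at random."""
    candidate = {}
    for name, (low, high) in continuous_parameters.items():
        candidate[name] = rng.uniform(low, high)   # float between low and high
    for name, choices in categorical_parameters.items():
        candidate[name] = rng.choice(choices)      # one of the listed values
    for name, (low, high) in integer_parameters.items():
        candidate[name] = rng.randint(low, high)   # int, both ends inclusive
    return candidate

rng = random.Random(42)
# One random individual per slot, matching population_size=16.
initial_population = [sample_candidate(rng) for _ in range(16)]
print(initial_population[0])
```

Each dict could then be unpacked into the estimator, e.g. DecisionTreeClassifier(**candidate), and scored with cross-validation.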



Download files


Source Distribution

sklearn-genetic-opt-0.2.0.dev0.tar.gz (8.3 kB)

Built Distribution

sklearn_genetic_opt-0.2.0.dev0-py3-none-any.whl (9.9 kB)

File details

Details for the file sklearn-genetic-opt-0.2.0.dev0.tar.gz.

File metadata

  • Download URL: sklearn-genetic-opt-0.2.0.dev0.tar.gz
  • Upload date:
  • Size: 8.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/52.0.0.post20210125 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.8.8

File hashes

Hashes for sklearn-genetic-opt-0.2.0.dev0.tar.gz
Algorithm Hash digest
SHA256 4d12afd138a22b0e4f7a691b07dbe3802254b45524e98d7e3c14acfbe2128c80
MD5 fd0231fc04e4570b79143b10aee92c26
BLAKE2b-256 4e57ebd3e02adfa2f918ce1a069539e6b04fad4bed611a01308d78eca85ce595
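The SHA256 digest above can be used to check that a downloaded archive was not corrupted or tampered with. A minimal sketch using Python's standard hashlib module; the file name assumes the source distribution was saved under its original name in the current directory.

```python
import hashlib
import os

def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify_download(path, expected_sha256):
    """True if the file's digest matches the published one (case-insensitive)."""
    return sha256_of(path) == expected_sha256.strip().lower()

# SHA256 published for sklearn-genetic-opt-0.2.0.dev0.tar.gz.
EXPECTED = "4d12afd138a22b0e4f7a691b07dbe3802254b45524e98d7e3c14acfbe2128c80"

path = "sklearn-genetic-opt-0.2.0.dev0.tar.gz"
if os.path.exists(path):
    print("OK" if verify_download(path, EXPECTED) else "MISMATCH")
```

pip can enforce the same check itself in hash-checking mode, by adding a --hash=sha256:... entry next to the pinned requirement in a requirements file.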


File details

Details for the file sklearn_genetic_opt-0.2.0.dev0-py3-none-any.whl.

File metadata

  • Download URL: sklearn_genetic_opt-0.2.0.dev0-py3-none-any.whl
  • Upload date:
  • Size: 9.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/52.0.0.post20210125 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.8.8

File hashes

Hashes for sklearn_genetic_opt-0.2.0.dev0-py3-none-any.whl
Algorithm Hash digest
SHA256 c2cc24acaefd1d2164ba79e27f4f92addb2831e731515c2fb404604a3687279a
MD5 55509eb999151ac9949c3b91301eba83
BLAKE2b-256 1369ab9f3053b72f4e255d5eddb0ed980b5d11fe8a2fe672696b9a4217d9d5f0

