
Scikit-learn model hyperparameter tuning using genetic algorithms

Project description


Sklearn-genetic-opt

Hyperparameter tuning for scikit-learn models using evolutionary algorithms.

This is meant to be an alternative to popular methods built into scikit-learn, such as Grid Search and Randomized Search.

Sklearn-genetic-opt uses evolutionary algorithms from the DEAP package to find the "best" set of hyperparameters that optimizes (maximizes or minimizes) the cross-validation score. It can be used for both regression and classification problems.
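For comparison, scikit-learn's built-in Grid Search exhaustively evaluates every combination in a fixed grid. A minimal sketch of that baseline on a decision tree (the grid values below are illustrative, not taken from this project) might look like:

from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_digits

X, y = load_digits(return_X_y=True)

# Exhaustive search over a small, hand-picked grid (values chosen only for illustration)
grid_search = GridSearchCV(DecisionTreeClassifier(),
                           param_grid={'criterion': ['gini', 'entropy'],
                                       'max_depth': [2, 5, 10, 20]},
                           scoring='accuracy',
                           cv=3,
                           n_jobs=-1)
grid_search.fit(X, y)
print(grid_search.best_params_)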

Usage:

Install sklearn-genetic-opt

It's advised to install sklearn-genetic-opt inside a virtual environment; within the environment, run:

pip install sklearn-genetic-opt

Example

from sklearn_genetic import GASearchCV
from sklearn_genetic.utils import plot_fitness_evolution
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt


# Load the digits dataset and split out features and target
data = load_digits()
y = data['target']
X = data['data']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

clf = DecisionTreeClassifier()

# Set up the evolutionary search over the hyperparameter space
evolved_estimator = GASearchCV(clf,
                               cv=3,
                               scoring='accuracy',
                               population_size=16,
                               generations=30,
                               tournament_size=3,
                               elitism=True,
                               crossover_probability=0.9,
                               mutation_probability=0.05,
                               continuous_parameters={'min_weight_fraction_leaf': (0, 0.5)},
                               categorical_parameters={'criterion': ['gini', 'entropy']},
                               integer_parameters={'max_depth': (2, 20), 'max_leaf_nodes': (2, 30)},
                               criteria='max',
                               n_jobs=-1,
                               verbose=True)

# Run the evolutionary search
evolved_estimator.fit(X_train, y_train)
# Best parameters found
print(evolved_estimator.best_params)
# Use the model fitted with the best parameters
y_predict_ga = evolved_estimator.predict(X_test)
print(accuracy_score(y_test, y_predict_ga))

# See the evolution of the optimization per generation
plot_fitness_evolution(evolved_estimator)
plt.show()

# Saved metadata for further analysis
print(evolved_estimator.history)
print(evolved_estimator.logbook)
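The description above notes that regression problems are also supported. A minimal sketch under that assumption, reusing the constructor arguments shown in the example and leaving the remaining ones at their defaults (not verified against the library), could look like:

from sklearn_genetic import GASearchCV
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score

# Load a regression dataset and split it
X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

regressor = DecisionTreeRegressor()

# Assumed: the same search arguments apply to regressors, with unspecified
# parameters (e.g. categorical_parameters) left at their defaults
evolved_regressor = GASearchCV(regressor,
                               cv=3,
                               scoring='r2',
                               population_size=16,
                               generations=20,
                               continuous_parameters={'min_weight_fraction_leaf': (0, 0.5)},
                               integer_parameters={'max_depth': (2, 20)},
                               criteria='max',
                               n_jobs=-1)

evolved_regressor.fit(X_train, y_train)
print(evolved_regressor.best_params)
print(r2_score(y_test, evolved_regressor.predict(X_test)))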

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

sklearn-genetic-opt-0.2.0.tar.gz (8.3 kB)

Uploaded Source

Built Distribution

sklearn_genetic_opt-0.2.0-py3-none-any.whl (9.8 kB)

Uploaded Python 3

File details

Details for the file sklearn-genetic-opt-0.2.0.tar.gz.

File metadata

  • Download URL: sklearn-genetic-opt-0.2.0.tar.gz
  • Upload date:
  • Size: 8.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/52.0.0.post20210125 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.8.8

File hashes

Hashes for sklearn-genetic-opt-0.2.0.tar.gz
  • SHA256: 354d70bdf7c770d71397af8e4fbbe1d5f0d6e2d06098071b77bd80c91bfd0954
  • MD5: faed9050b4f168ca8ce9c300079cb2e8
  • BLAKE2b-256: a05e34a922a212390114ad91b9d79465f8c879477dfd7f3e65f55255ebcbe22e

See more details on using hashes here.

File details

Details for the file sklearn_genetic_opt-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: sklearn_genetic_opt-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 9.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/52.0.0.post20210125 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.8.8

File hashes

Hashes for sklearn_genetic_opt-0.2.0-py3-none-any.whl
  • SHA256: 4795521d37b287dd4c44046983bc06ba1972620128ec5e2bcbce81f114b1bc9d
  • MD5: 259fac404dcfa15a5549684654c462eb
  • BLAKE2b-256: 61be733e026f31e7635e36274f1ffa03ed52f498d24a8cf5b93e9b6c7804d74a

See more details on using hashes here.
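As an illustration of how a published hash can be checked, a minimal Python sketch that verifies the wheel's SHA256 listed above (the file path assumes the wheel was downloaded to the working directory):

import hashlib

# Expected SHA256 of sklearn_genetic_opt-0.2.0-py3-none-any.whl, as listed above
expected = "4795521d37b287dd4c44046983bc06ba1972620128ec5e2bcbce81f114b1bc9d"

# Hash the file in chunks to avoid loading it all into memory
sha256 = hashlib.sha256()
with open("sklearn_genetic_opt-0.2.0-py3-none-any.whl", "rb") as f:
    for chunk in iter(lambda: f.read(8192), b""):
        sha256.update(chunk)

print(sha256.hexdigest() == expected)  # True if the download is intact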
