A scikit-learn wrapper for HpBandSter hyper parameter search

hpbandster-sklearn

hpbandster-sklearn is a Python library providing a scikit-learn wrapper - HpBandSterSearchCV - for HpBandSter, a hyper parameter tuning library.

Motivation

HpBandSter implements several cutting-edge hyper parameter optimization algorithms, including HyperBand and BOHB. These often outperform standard Random Search, finding the best parameter combinations in less time.

HpBandSter is powerful and configurable, but its usage is often unintuitive for beginners and requires a large amount of boilerplate code. To address that, HpBandSterSearchCV was created as a drop-in replacement for scikit-learn hyper parameter searchers. It follows scikit-learn's well-known API, making it possible to tune scikit-learn API estimators with minimal setup.

The HpBandSterSearchCV API is based on scikit-learn's HalvingRandomSearchCV and implements nearly all of its parameters.

Installation

pip install hpbandster-sklearn

Usage

Use it like any other scikit-learn hyper parameter searcher:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from hpbandster_sklearn import HpBandSterSearchCV

X, y = load_iris(return_X_y=True)
clf = RandomForestClassifier(random_state=0)
np.random.seed(0)  # seed NumPy's global RNG for reproducible results

# Plain dict of candidate values, as with scikit-learn's random searchers
param_distributions = {"max_depth": [2, 3, 4], "min_samples_split": list(range(2, 12))}

search = HpBandSterSearchCV(clf, param_distributions, random_state=0, n_jobs=1, n_iter=10, verbose=1).fit(X, y)
print(search.best_params_)

You can also pass a ConfigSpace.ConfigurationSpace object instead of a dict; in fact, this is the recommended approach:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from hpbandster_sklearn import HpBandSterSearchCV
import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH

X, y = load_iris(return_X_y=True)
clf = RandomForestClassifier(random_state=0)
np.random.seed(0)  # seed NumPy's global RNG for reproducible results

# Same search space as above, expressed as a ConfigurationSpace (integer bounds are inclusive)
param_distributions = CS.ConfigurationSpace(seed=42)
param_distributions.add_hyperparameter(CSH.UniformIntegerHyperparameter("min_samples_split", 2, 11))
param_distributions.add_hyperparameter(CSH.UniformIntegerHyperparameter("max_depth", 2, 4))

search = HpBandSterSearchCV(clf, param_distributions, random_state=0, n_jobs=1, n_iter=10, verbose=1).fit(X, y)
print(search.best_params_)

Please refer to the documentation of this library, as well as to the documentation of HpBandSter and ConfigSpace for more information.

Pipelines and TransformedTargetRegressor are also supported. If you use either (or both), prefix the hyper parameter and resource names accordingly, following scikit-learn's double-underscore convention, for example final_estimator__n_estimators. The n_samples resource must not be prefixed.
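The prefixing rule is scikit-learn's nested-parameter convention (`<step name>__<parameter>`). The sketch below uses only scikit-learn to check that prefixed names resolve to real parameters; the step name final_estimator is an arbitrary choice matching the example name above, and the resulting dict and resource name would then be passed to HpBandSterSearchCV as param_distributions and resource_name (that call is not run here).

```python
from sklearn.compose import TransformedTargetRegressor
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# A pipeline whose final step is (arbitrarily) named "final_estimator"
pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("final_estimator", RandomForestClassifier(random_state=0)),
])

# Step parameters are addressed as "<step name>__<parameter>"
param_distributions = {
    "final_estimator__max_depth": [2, 3, 4],
    "final_estimator__min_samples_split": list(range(2, 12)),
}
resource_name = "final_estimator__n_estimators"  # an attribute resource gets the same prefix

# Every prefixed name should be a settable parameter of the pipeline
valid_names = pipe.get_params().keys()
missing = [name for name in [*param_distributions, resource_name] if name not in valid_names]
print(missing)  # []

# TransformedTargetRegressor exposes its wrapped model's parameters under "regressor__"
ttr = TransformedTargetRegressor(regressor=RandomForestRegressor(random_state=0))
print("regressor__n_estimators" in ttr.get_params())  # True
```

n_samples stays unprefixed, presumably because it describes the data being fitted rather than a parameter of any pipeline step.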

Early stopping

Because almost every search algorithm in HpBandSter relies on early stopping (mostly through Successive Halving), the resource and budget can be configured through the arguments of the HpBandSterSearchCV object.

search = HpBandSterSearchCV(
    clf,
    param_distributions,
    resource_name='n_samples',  # either 'n_samples' or the name of an estimator attribute, e.g. 'n_estimators' for an ensemble
    resource_type=float,  # if specified, the resource value is cast to this type before being passed to the estimator; otherwise it is derived automatically
    min_budget=0.2,
    max_budget=1,
)
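With resource_name='n_samples', the float budgets above (0.2 to 1) presumably act as fractions of the training data. The following is a minimal sketch of that idea only, assuming one fixed shuffle so that each larger budget extends the previous subset; how HpBandSterSearchCV actually selects its subsamples (shuffling, stratification) is not specified here, and the budgets 0.2, 0.6 and 1.0 are illustrative.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
rng = np.random.default_rng(0)
order = rng.permutation(len(X))  # one fixed shuffle, reused for every budget

fitted_sizes = []
for budget in (0.2, 0.6, 1.0):  # illustrative budgets between min_budget and max_budget
    n_rows = max(1, round(budget * len(X)))  # budget interpreted as a fraction of the rows
    rows = order[:n_rows]  # larger budgets extend the smaller subsets
    RandomForestClassifier(random_state=0).fit(X[rows], y[rows])
    fitted_sizes.append(n_rows)

print(fitted_sizes)  # [30, 90, 150]
```

Alternatively, an estimator attribute can serve as the resource: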

search = HpBandSterSearchCV(
    clf,
    param_distributions,
    resource_name='n_estimators',  # either 'n_samples' or the name of an estimator attribute, e.g. 'n_estimators' for an ensemble
    resource_type=int,  # if specified, the resource value is cast to this type before being passed to the estimator; otherwise it is derived automatically
    min_budget=20,
    max_budget=200,
)
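To make min_budget and max_budget concrete, here is a generic sketch of a single Successive Halving bracket: budgets grow geometrically by a halving factor (called eta here) from roughly min_budget up to max_budget, and only the best 1/eta of the configurations advance to each next rung. The helper name and eta=3 are assumptions for illustration; HpBandSter's own scheduling (HyperBand and BOHB run several brackets) may differ in details such as rounding.

```python
import math


def successive_halving_schedule(n_configs, min_budget, max_budget, eta=3):
    """Hypothetical helper: (configs evaluated, budget) for each rung of one bracket."""
    # Largest number of rungs whose smallest budget is still >= min_budget
    # (the small epsilon guards against floating-point error in the log).
    n_rungs = int(math.floor(math.log(max_budget / min_budget, eta) + 1e-9)) + 1
    schedule = []
    for rung in range(n_rungs):
        # The final rung runs at max_budget; each earlier rung uses 1/eta of the next.
        budget = max_budget * eta ** (rung - (n_rungs - 1))
        # Only the best 1/eta of the configurations survive each rung.
        n_survivors = max(1, n_configs // eta ** rung)
        schedule.append((n_survivors, budget))
    return schedule


for n, budget in successive_halving_schedule(n_configs=9, min_budget=20, max_budget=200):
    print(f"{n} config(s) at budget {budget:.1f}")
```

In this sketch, the values from the n_estimators example give three rungs at budgets of about 22.2, 66.7 and 200, evaluating 9, 3 and 1 configurations; with resource_type=int, such values would be cast to whole numbers of trees before reaching the estimator.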

By default, the object tries to determine a suitable resource automatically, checking the following in order:

  • 'n_estimators', if the model has that attribute and the warm_start attribute
  • 'max_iter', if the model has that attribute and the warm_start attribute
  • 'n_samples', otherwise (e.g. if the model doesn't support warm_start): the dataset samples are used as the resource instead, meaning the model is fitted on progressively larger portions of the dataset.
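The lookup order above can be sketched as a small function over an estimator's get_params(). This is a hypothetical re-implementation for illustration only, not the library's code, and it ignores any library-specific special cases.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC


def guess_resource(estimator):
    """Hypothetical helper mirroring the documented lookup order."""
    params = estimator.get_params()
    # Attribute resources are only considered when warm_start is available
    if "warm_start" in params:
        for name in ("n_estimators", "max_iter"):
            if name in params:
                return name
    # Fall back to growing the number of training samples
    return "n_samples"


print(guess_resource(RandomForestClassifier()))  # n_estimators
print(guess_resource(LogisticRegression()))      # max_iter
print(guess_resource(SVC()))                     # n_samples (SVC has no warm_start)
```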

Furthermore, special support has been added for LightGBM, XGBoost and CatBoost scikit-learn estimators.

Documentation

https://hpbandster-sklearn.readthedocs.io/en/latest/

Author

Antoni Baum (Yard1)

License

MIT

Download files

Source Distribution

hpbandster-sklearn-2.0.2.tar.gz (25.8 kB)

Uploaded Source

Built Distribution

hpbandster_sklearn-2.0.2-py3-none-any.whl (27.6 kB)

Uploaded Python 3

File details

Details for the file hpbandster-sklearn-2.0.2.tar.gz.

File metadata

  • Download URL: hpbandster-sklearn-2.0.2.tar.gz
  • Size: 25.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.15

File hashes

Hashes for hpbandster-sklearn-2.0.2.tar.gz
Algorithm Hash digest
SHA256 3525d256015a05621eee1632d30a70a9cf63d5fd370d35d7cf556ffd9e695efa
MD5 3420dbffcf96ef2fc7fbd9e1479029c7
BLAKE2b-256 96659b71c962bb4a07aace3aa95f3dab4d3fb26183ba3963750782de1abf4fb2

File details

Details for the file hpbandster_sklearn-2.0.2-py3-none-any.whl.

File hashes

Hashes for hpbandster_sklearn-2.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 881f2fd4e44857fe25106009c266ece7b14baae4cea289fafcd7e26e7295f83a
MD5 03229d7c6a9451e69cb724173cbd7e3f
BLAKE2b-256 f54e644c4ff018829cd0a8e15aef2f14d6f5fe94482b9c6bf0eb8edfbe89956e
