Grid search hyper-parameter optimization using a validation set (not cross validation)
Project description
A Python machine learning package for grid search hyper-parameter optimization using a validation set (it defaults to cross-validation when no validation set is available). This package supports Python 2.7+ and Python 3+, works with any model (classification and regression), and automatically runs in parallel across all CPU threads.
scikit-learn provides grid-search hyper-parameter optimization **using cross-validation** on the training dataset. Unfortunately, cross-validation is impractically slow for large datasets and fails for small datasets because each fold lacks sufficient data in each class to properly train the model. Instead, we use a fixed validation set to optimize hyper-parameters; the hypopt package makes this fast (distributed across all CPU threads) and easy (one line of code).
hypopt.model_selection.fit_model_with_grid_search supports grid search hyper-parameter optimization when you already have a validation set, eliminating the extra hours of training time required by cross-validation. When no validation set is given, it defaults to cross-validation on the training set. This allows you to always use hypopt whenever you need grid-search hyper-parameter optimization, regardless of whether you use a validation set or cross-validation.
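For example, a minimal sketch of both modes (SVC and the parameter grid are illustrative; X_train, y_train, X_val, y_val are assumed to already exist):
from hypopt import GridSearch
from sklearn.svm import SVC
gs = GridSearch(model=SVC(), param_grid=[{'C': [1, 10, 100]}])
# With a validation set, each parameter setting is scored on (X_val, y_val):
gs.fit(X_train, y_train, X_val, y_val)
# With no validation set, hypopt falls back to cross-validation on the training set:
gs.fit(X_train, y_train)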
Installation
Python 2.7, 3.4, 3.5, and 3.6 are supported.
Stable release:
$ pip install hypopt
Developer (unstable) release:
$ pip install git+https://github.com/cgnorthcutt/hypopt.git
To install the codebase (enabling you to make modifications):
$ conda update pip # if you use conda
$ git clone https://github.com/cgnorthcutt/hypopt.git
$ cd hypopt
$ pip install -e .
Examples
Basic usage
# Assuming you already have train, test, val sets and a model.
from sklearn.svm import SVR
from hypopt import GridSearch
param_grid = [
    {'C': [1, 10, 100], 'kernel': ['linear']},
    {'C': [1, 10, 100], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']},
]
# Grid-search all parameter combinations using a validation set.
gs = GridSearch(model=SVR(), param_grid=param_grid)
gs.fit(X_train, y_train, X_val, y_val)
print('Test Score for Optimized Parameters:', gs.score(X_test, y_test))
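If you don't already have separate train/val/test splits, one simple way to create them is scikit-learn's train_test_split (the 60/20/20 proportions below are just an example):
from sklearn.model_selection import train_test_split
# Hold out 20% as the test set, then 25% of the remainder as validation,
# yielding a 60/20/20 train/val/test split.
X_rest, X_test, y_rest, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X_rest, y_rest, test_size=0.25, random_state=0)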
Choosing the scoring metric to optimize
The default metric is the model.score() function, so in the previous example SVR().score() is optimized, which for scikit-learn regressors like SVR defaults to the R² coefficient of determination (for classifiers, score() defaults to accuracy).
You can select a different scoring metric with the scoring parameter of hypopt.GridSearch.fit():
# This will use f1 score as the scoring metric that you optimize.
gs.fit(X_train, y_train, X_val, y_val, scoring='f1')
For classification, hypopt supports these string-named metrics: 'accuracy', 'brier_score_loss', 'average_precision', 'f1', 'f1_micro', 'f1_macro', 'f1_weighted', 'neg_log_loss', 'precision', 'recall', and 'roc_auc'.
For regression, hypopt supports: 'explained_variance', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', and 'r2'.
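Since the basic-usage example above uses SVR (a regressor), a regression metric is the natural choice there, e.g.:
gs.fit(X_train, y_train, X_val, y_val, scoring='neg_mean_absolute_error')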
You can also create your own metric, your_custom_scoring_func(y_true, y_pred), by wrapping it into a scorer object using sklearn.metrics.make_scorer:
from sklearn.metrics import make_scorer
scorer = make_scorer(your_custom_scoring_func)
gs.fit(X_train, y_train, X_val, y_val, scoring=scorer)
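For instance, a hypothetical cost-sensitive metric might look like this (the function and its 2:1 weighting are purely illustrative, not part of hypopt):
import numpy as np
# Hypothetical metric: penalize false negatives twice as heavily as false
# positives, returned as a negative cost so that higher scores are better.
def weighted_error(y_true, y_pred):
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    false_neg = np.sum((y_true == 1) & (y_pred == 0))
    false_pos = np.sum((y_true == 0) & (y_pred == 1))
    return -(2.0 * false_neg + false_pos)
gs.fit(X_train, y_train, X_val, y_val, scoring=make_scorer(weighted_error))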
Minimal working examples
Other examples, including a working example with MNIST, are available in the hypopt GitHub repository.
Use hypopt with any model (PyTorch, Tensorflow, caffe2, scikit-learn, etc.)
All of the features of the hypopt package work with any model. Yes, any model. Feel free to use PyTorch, Tensorflow, caffe2, scikit-learn, mxnet, etc. If you use a scikit-learn model, all hypopt methods will work out-of-the-box. It's also easy to use your favorite model from a non-scikit-learn package: just wrap your model in a Python class that inherits from sklearn.base.BaseEstimator. Here's an example for a generic classifier:
from sklearn.base import BaseEstimator

class YourModel(BaseEstimator):  # Inherits the sklearn base estimator
    def __init__(self):
        pass

    def fit(self, X, y, sample_weight=None):
        pass

    def predict(self, X):
        pass

    def score(self, X, y, sample_weight=None):
        pass

    # Inheriting BaseEstimator gives you these for free!
    # So if you inherit, there's no need to implement these.
    def get_params(self, deep=True):
        pass

    def set_params(self, **params):
        pass
PyTorch MNIST CNN Example
Check out a PyTorch MNIST CNN wrapped in the above class in the hypopt repository. You can use any instantiation of this class with hypopt just as you would any scikit-learn model. Another example of a fully compliant class is the LearningWithNoisyLabels() model.
If you don't wish to write this code yourself, there are existing packages that do it for you. For PyTorch, check out the skorch Python package (https://skorch.readthedocs.io/en/stable/), which wraps your PyTorch model into a scikit-learn compliant model.
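A minimal sketch of that route (assumes skorch and torch are installed; the MLP architecture, the hyper-parameter grid, and the float32/int64 numpy inputs below are illustrative assumptions, not hypopt requirements):
import torch
from torch import nn
from skorch import NeuralNetClassifier
from hypopt import GridSearch

class MLP(nn.Module):
    def __init__(self, hidden=32):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(20, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2),
        )

    def forward(self, X):
        # skorch's default criterion is NLLLoss, so return log-probabilities.
        return torch.log_softmax(self.layers(X), dim=-1)

# NeuralNetClassifier is already scikit-learn compliant, so it plugs
# straight into hypopt's GridSearch. 'module__hidden' uses skorch's
# double-underscore routing to set the MLP's constructor argument.
model = NeuralNetClassifier(MLP, max_epochs=10, lr=0.1, verbose=0)
param_grid = [{'lr': [0.01, 0.1], 'module__hidden': [16, 32]}]
gs = GridSearch(model=model, param_grid=param_grid)
gs.fit(X_train.astype('float32'), y_train.astype('int64'),
       X_val.astype('float32'), y_val.astype('int64'))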
Project details
Download files
Download the file for your platform.
Source Distribution: hypopt-1.0.9.tar.gz
Built Distribution: hypopt-1.0.9-py2.py3-none-any.whl
File details
Details for the file hypopt-1.0.9.tar.gz.
File metadata
- Download URL: hypopt-1.0.9.tar.gz
- Upload date:
- Size: 11.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.19.1 setuptools/40.2.0 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.7.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | d5ce7032668c7e71f63dcee547dfbc2401931879a0b8b70847f084d68f25933f
MD5 | 5903ea6d85ae20b66b74619847c4df29
BLAKE2b-256 | dfdd1bf09815809af520707455ea2f59d061b427c9da216cfad55c2dcc53f8ad
File details
Details for the file hypopt-1.0.9-py2.py3-none-any.whl.
File metadata
- Download URL: hypopt-1.0.9-py2.py3-none-any.whl
- Upload date:
- Size: 13.1 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.19.1 setuptools/40.2.0 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.7.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 53bb5fd98d09b634101960778b3c6131a896cb2d6d30792388f56acf01ed7880
MD5 | aa1ab29fa04340025549f194375490fa
BLAKE2b-256 | 6e8b17f9022d94066ec29ab0008ed1ad247615153e5c633c2787255cfe2e95b8