Grid search hyper-parameter optimization using a validation set (not cross validation)

A Python machine learning package for grid search hyper-parameter optimization using a validation set (defaults to cross validation when no validation set is available). This package works for Python 2.7+ and Python 3+, for any model (classification and regression), and runs in parallel on all threads on your CPU automatically.

scikit-learn provides a package for grid-search hyper-parameter optimization **using cross-validation** on the training dataset. Unfortunately, cross-validation is impractically slow for large datasets and fails for small datasets due to the lack of data in each class needed to properly train each fold. Instead, we use a constant validation set to optimize hyper-parameters – the hypopt package makes this fast (distributed on all CPU threads) and easy (one line of code).

hypopt.model_selection.fit_model_with_grid_search supports grid search hyper-parameter optimization when you already have a validation set, eliminating the extra hours of training time required when using cross-validation. However, when no validation set is given, it defaults to using cross-validation on the training set. This lets you use hypopt whenever you need grid-search hyper-parameter optimization, whether you have a validation set or rely on cross-validation.
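The behavior described above (score on a validation set when one is given, otherwise fall back to cross-validation) can be sketched with scikit-learn alone. This is only an illustration, not hypopt's actual implementation: the helper name grid_search_sketch, its cv_folds argument, and the synthetic data are assumptions made for this example, and the sketch runs serially rather than in parallel.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_regression
from sklearn.model_selection import ParameterGrid, cross_val_score, train_test_split
from sklearn.svm import SVR


def grid_search_sketch(model, param_grid, X_train, y_train,
                       X_val=None, y_val=None, cv_folds=3):
    """Hypothetical helper (not part of hypopt): return the best model, params, score."""
    best_score, best_params = -np.inf, None
    for params in ParameterGrid(param_grid):
        candidate = clone(model).set_params(**params)
        if X_val is not None and y_val is not None:
            # One fit per candidate, scored on the fixed validation set.
            score = candidate.fit(X_train, y_train).score(X_val, y_val)
        else:
            # No validation set: average the score over cv_folds folds of the training set.
            score = cross_val_score(candidate, X_train, y_train, cv=cv_folds).mean()
        if score > best_score:
            best_score, best_params = score, params
    # Refit the winning configuration on the training data.
    best_model = clone(model).set_params(**best_params).fit(X_train, y_train)
    return best_model, best_params, best_score


# Synthetic regression data, used only to make the sketch runnable.
X, y = make_regression(n_samples=200, n_features=5, noise=0.1, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=0)
param_grid = [
    {'C': [1, 10, 100], 'kernel': ['linear']},
    {'C': [1, 10, 100], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']},
]

# With a validation set: one fit per parameter combination.
val_model, val_params, val_score = grid_search_sketch(
    SVR(), param_grid, X_train, y_train, X_val, y_val)
# Without a validation set: cv_folds fits per parameter combination.
cv_model, cv_params, cv_score = grid_search_sketch(SVR(), param_grid, X_train, y_train)
print('Best params (validation set):', val_params)
print('Best params (cross-validation):', cv_params)
```

With a validation set, each parameter combination costs one fit; with the cross-validation fallback it costs cv_folds fits, which is where the time saving comes from.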

Installation

Python 2.7, 3.4, 3.5, and 3.6 are supported.

Stable release:

$ pip install hypopt

Developer (unstable) release:

$ pip install git+https://github.com/cgnorthcutt/hypopt.git

To install the codebase (enabling you to make modifications):

$ conda update pip # if you use conda
$ git clone https://github.com/cgnorthcutt/hypopt.git
$ cd hypopt
$ pip install -e .

Examples

Basic usage

# Assuming you already have train, test, val sets and a model.
from hypopt import GridSearch
from sklearn.svm import SVR  # the model tuned in this example
param_grid = [
  {'C': [1, 10, 100], 'kernel': ['linear']},
  {'C': [1, 10, 100], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']},
]
# Grid-search all parameter combinations using a validation set.
gs = GridSearch(model=SVR())
gs.fit(X_train, y_train, param_grid, X_val, y_val)
print('Test Score for Optimized Parameters:', gs.score(X_test, y_test))
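To see what "all parameter combinations" means for the grid above, here is a small standard-library sketch that expands a list of sub-grids into individual settings. The helper expand_grid is hypothetical and written only for this illustration; it is not a hypopt function, though it follows the same list-of-dicts convention that scikit-learn uses for parameter grids.

```python
from itertools import product


def expand_grid(param_grid):
    """Hypothetical helper: expand a list of sub-grids into one dict per combination."""
    combos = []
    for sub_grid in param_grid:
        names = sorted(sub_grid)
        # Cartesian product of the value lists within this sub-grid only.
        for values in product(*(sub_grid[name] for name in names)):
            combos.append(dict(zip(names, values)))
    return combos


param_grid = [
    {'C': [1, 10, 100], 'kernel': ['linear']},
    {'C': [1, 10, 100], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']},
]
combos = expand_grid(param_grid)
print(len(combos))  # 9: three linear settings plus 3 * 2 rbf settings
for params in combos:
    print(params)
```

Each sub-grid is expanded independently, so gamma is only searched together with the rbf kernel and never with the linear one.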

Minimal working examples

Other Examples including a working example with MNIST

Use hypopt with any model (PyTorch, Tensorflow, caffe2, scikit-learn, etc.)

All of the features of the hypopt package work with any model. Yes, any model. Feel free to use PyTorch, Tensorflow, caffe2, scikit-learn, mxnet, etc. If you use a scikit-learn model, all hypopt methods will work out-of-the-box. It’s also easy to use your favorite model from a non-scikit-learn package: just wrap your model in a Python class that inherits from sklearn.base.BaseEstimator. Here’s an example for a generic classifier:

from sklearn.base import BaseEstimator
class YourModel(BaseEstimator): # Inherits from the sklearn base estimator
    def __init__(self):
        pass
    def fit(self, X, y, sample_weight = None):
        pass
    def predict(self, X):
        pass
    def score(self, X, y, sample_weight = None):
        pass

    # Inheriting from BaseEstimator gives you these for free!
    # So if you inherit, there's no need to implement these.
    def get_params(self, deep = True):
        pass
    def set_params(self, **params):
        pass
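As a concrete version of the skeleton above, here is a minimal sketch of a wrapped model: a nearest-centroid classifier written in NumPy. The class CentroidClassifier and its hyper-parameter p are invented for this illustration and are not part of hypopt or scikit-learn. The point it shows is that hyper-parameters stored in __init__ under their own names become visible to the inherited get_params and set_params, which is what a grid search relies on to try new settings.

```python
import numpy as np
from sklearn.base import BaseEstimator


class CentroidClassifier(BaseEstimator):
    """Illustrative wrapper: predicts the class whose centroid is closest."""

    def __init__(self, p=2):
        # Store each hyper-parameter under its own name so that the inherited
        # get_params / set_params can find it.
        self.p = p

    def fit(self, X, y, sample_weight=None):
        X, y = np.asarray(X, dtype=float), np.asarray(y)
        weights = None if sample_weight is None else np.asarray(sample_weight)
        self.classes_ = np.unique(y)
        # One (optionally weighted) mean vector per class.
        self.centroids_ = np.array([
            np.average(X[y == c], axis=0,
                       weights=None if weights is None else weights[y == c])
            for c in self.classes_
        ])
        return self

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        # Pairwise p-norm distances, shape (n_samples, n_classes).
        diffs = X[:, None, :] - self.centroids_[None, :, :]
        dists = np.linalg.norm(diffs, ord=self.p, axis=2)
        return self.classes_[dists.argmin(axis=1)]

    def score(self, X, y, sample_weight=None):
        # (Optionally weighted) accuracy.
        correct = self.predict(X) == np.asarray(y)
        return float(np.average(correct, weights=sample_weight))


X = np.array([[0.0, 0.0], [0.1, 0.2], [1.0, 1.0], [0.9, 1.1]])
y = np.array([0, 0, 1, 1])
model = CentroidClassifier(p=1).fit(X, y)
print(model.get_params())                       # inherited from BaseEstimator
print(model.predict([[0.2, 0.1], [0.8, 0.9]]))
print(model.score(X, y))
```

A parameter grid for this wrapper would then look like {'p': [1, 2]}, since the grid keys must match the argument names of __init__.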

PyTorch MNIST CNN Example

