hypt

Simple hyperparameter tuning in Python.

I wrote hypt as a minimalistic hyperparameter tuning library for quick experiments, when established libraries like Optuna feel like overkill. My goal was for it to work as a simple, easy-to-debug for loop over hyperparameter values, so that I don't have to rewrite my whole training script around it.

As such, hypt will keep a small footprint and avoid implementing things like experiment tracking, results visualization, parallelization, etc. I will also probably focus more on "off the beaten path" approaches to hyperparameter optimization.

Installation

hypt can be installed through pip:

pip install hypt

Getting Started with Random Search

The following is an illustrative example of tuning the parameters of a gradient-boosted decision tree (GBDT) model using 50 trials of random search:

import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.ensemble import HistGradientBoostingRegressor
from tqdm import tqdm
import hypt
import hypt.random as r

# define random search
hparams = hypt.RandomSearch({
    'loss': 'squared_error',
    'learning_rate': r.LogUniform(0.001, 0.5),
    'max_iter': 200,
    'max_leaf_nodes': r.IntLogUniform(16, 256),
    'min_samples_leaf': r.IntLogUniform(1, 100),
    'l2_regularization': r.OrZero(r.LogUniform(0.001, 10)),  # half of samples will be 0
    'validation_fraction': 0.2,
    'n_iter_no_change': 10,
    'random_state': 1984,
}, num_samples=50, seed=123)

# get data
X, y = fetch_california_housing(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# hpt loop
val_results = []
test_results = []
for hparam in tqdm(hparams): # progress bar
    gbm = HistGradientBoostingRegressor(**hparam) # hparam is a simple dict
    gbm.fit(X_train, y_train)

    val_results.append(gbm.validation_score_.max())
    test_results.append(gbm.score(X_test, y_test))

# print best hparam and test score
best = np.argmax(val_results)
print('Best params:')
for k, v in hparams[best].items():
    print(f'\t{k} : {v}')
print('Test r2 score:', test_results[best])

Outputs:

100%|██████████| 50/50 [01:27<00:00,  1.76s/it]
Best params:
	learning_rate : 0.16311465153429477
	max_leaf_nodes : 33
	min_samples_leaf : 23
	l2_regularization : 0.06800582912648902
	loss : squared_error
	max_iter : 200
	validation_fraction : 0.2
	n_iter_no_change : 10
	random_state : 1984
Test r2 score: 0.8447968218784379
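To give a feel for what the distributions in the search space above do, here is a minimal sketch using only the Python standard library. It is not hypt's implementation: the helper names log_uniform, int_log_uniform and or_zero are hypothetical stand-ins, and the rounding and clipping behaviour is an assumption made for illustration.

```python
import math
import random


def log_uniform(rng, low, high):
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def int_log_uniform(rng, low, high):
    """Log-uniform draw rounded to an integer and clipped to [low, high]."""
    return min(max(round(log_uniform(rng, low, high)), low), high)


def or_zero(rng, sampler, p_zero=0.5):
    """Return 0 with probability p_zero, otherwise defer to the sampler."""
    return 0.0 if rng.random() < p_zero else sampler(rng)


# Hypothetical search space mirroring part of the example above.
rng = random.Random(123)
samples = [
    {
        "learning_rate": log_uniform(rng, 0.001, 0.5),
        "max_leaf_nodes": int_log_uniform(rng, 16, 256),
        "l2_regularization": or_zero(rng, lambda g: log_uniform(g, 0.001, 10)),
    }
    for _ in range(5)
]

for s in samples:
    print(s)
```

Sampling in log space spreads the trials evenly across orders of magnitude, which is usually what you want for scale-like parameters such as the learning rate.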

Informed Line Searches

Now let's try something a bit more unconventional that relies on my past experience tuning GBDT models. We will restrict ourselves to tuning only 3 hyperparameters, using something closer to a grid search:

  1. The total number of leaf nodes
  2. The minimum number of samples per leaf
  3. The learning rate

The first two hyperparameters control the regularization of the model. We always start the search with the most regularized setting (the fewest leaf nodes and the largest minimum number of samples per leaf). For each setting of the regularization parameters, we find the best learning rate using a golden-section search with 5 function evaluations. If the function value gets worse as we reduce the regularization provided by one of these parameters, we stop the search in that direction.
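For readers unfamiliar with golden-section search, the following is a minimal standalone sketch of the technique with a fixed evaluation budget. It is an illustration under stated assumptions, not hypt's code: the function golden_section_search, its arguments, and the toy objective are all hypothetical. Two interior points are evaluated first, and each later step reuses one of them, so the budget is spent at one new evaluation per step.

```python
import math


def golden_section_search(f, low, high, num_evals=5, log=True):
    """Minimize a unimodal f on [low, high] with exactly num_evals calls."""
    if num_evals < 2:
        raise ValueError("golden-section search needs at least 2 evaluations")
    to_x = math.exp if log else (lambda t: t)
    a, b = (math.log(low), math.log(high)) if log else (low, high)
    inv_phi = (math.sqrt(5) - 1) / 2  # about 0.618

    # Two interior points placed at golden-ratio positions.
    c = b - inv_phi * (b - a)
    d = a + inv_phi * (b - a)
    fc, fd = f(to_x(c)), f(to_x(d))

    for _ in range(num_evals - 2):
        if fc < fd:
            # Minimum lies in [a, d]: the old c becomes the new d.
            b, d, fd = d, c, fc
            c = b - inv_phi * (b - a)
            fc = f(to_x(c))
        else:
            # Minimum lies in [c, b]: the old d becomes the new c.
            a, c, fc = c, d, fd
            d = a + inv_phi * (b - a)
            fd = f(to_x(d))

    return (to_x(c), fc) if fc < fd else (to_x(d), fd)


# Toy objective (made up): squared log-distance to a "true" optimum of 0.1.
calls = []


def toy_loss(lr):
    calls.append(lr)
    return (math.log(lr) - math.log(0.1)) ** 2


best_lr, best_loss = golden_section_search(toy_loss, 0.001, 0.5, num_evals=5)
print(f"best lr ~ {best_lr:.4f} after {len(calls)} evaluations")
```

Searching in log space (log=True) matches the intuition that learning rates matter on a multiplicative scale, so 0.001 to 0.01 deserves as much attention as 0.05 to 0.5.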

The hypt code implementing this search strategy is the following:

import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.ensemble import HistGradientBoostingRegressor
from tqdm import tqdm

import hypt
from hypt.linesearch import NestedLineSearch, GoldenSearch, LineSearch

# define the nested line search
# note that the order of the dynamic parameters matters!
# the first parameter corresponds to the outermost loop
hparams = NestedLineSearch({
    'loss': 'squared_error',
    'max_iter': 200,
    'validation_fraction': 0.2,
    'n_iter_no_change': 10,
    'random_state': 1984,
    'l2_regularization': 0,
    'max_leaf_nodes': LineSearch([16, 32, 64, 128, 256], patience=1),
    'min_samples_leaf': LineSearch([30, 10, 1], patience=1),
    'learning_rate': GoldenSearch(0.001, 0.5, num_evals=5, log=True),
})
# wrapper utility to automatically record objective values and parameters
hparams = hypt.Recorder(hparams) 

# get data
X, y = fetch_california_housing(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# hpt loop is similar
val_results = []
test_results = []
for hparam in tqdm(hparams): # progress bar
    gbm = HistGradientBoostingRegressor(**hparam) # hparam is a simple dict
    gbm.fit(X_train, y_train)

    val_results.append(gbm.validation_score_.max())
    test_results.append(gbm.score(X_test, y_test))

    # this is the main change: the search procedure needs feedback
    # (i.e., the function value) after every trial. The validation score
    # is negated here, presumably because the search minimizes its objective
    hparams.feedback(-val_results[-1])

# print best hparam and test score
best = np.argmax(val_results)
print('Best params:')
for k, v in hparams.best_params().items():
    print(f'\t{k} : {v}')
print('Test r2 score:', test_results[hparams.best_iteration()])

which outputs:

30it [00:30,  1.01s/it]
Best params:
	max_leaf_nodes : 32
	min_samples_leaf : 30
	learning_rate : 0.1153000814478148
	loss : squared_error
	max_iter : 200
	validation_fraction : 0.2
	n_iter_no_change : 10
	random_state : 1984
	l2_regularization : 0
Test r2 score: 0.8449648371371464

We obtained a similar test $R^2$ score and found similar hyperparameters in about 1/3 of the time (30 s versus 1 min 27 s)! Note that this speedup was only possible due to the early stopping in the line searches, which allowed us to evaluate only the three smallest values of max_leaf_nodes. We also evaluated only the first two values of min_samples_leaf for each value of max_leaf_nodes. With 5 learning rate evaluations per setting, that gives 3 × 2 × 5 = 30 trials, whereas a full grid search would have required evaluating 5 × 3 × 5 = 75 different hyperparameter configurations.
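The early stopping described above can be sketched as a line search with a patience counter. This is a hedged illustration rather than hypt's LineSearch: the function line_search, the exact meaning given to patience (stop after that many consecutive non-improving values), and the loss values are assumptions made up for the example.

```python
def line_search(values, objective, patience=1):
    """Try values in order; stop after `patience` consecutive non-improvements."""
    best_value, best_score = None, float("inf")
    bad_steps = 0
    evaluated = []
    for v in values:
        score = objective(v)
        evaluated.append(v)
        if score < best_score:
            best_value, best_score = v, score
            bad_steps = 0
        else:
            bad_steps += 1
            if bad_steps >= patience:
                break  # loosening regularization further is not helping
    return best_value, best_score, evaluated


# Made-up validation losses for each max_leaf_nodes value (lower is better).
fake_loss = {16: 0.20, 32: 0.16, 64: 0.17, 128: 0.19, 256: 0.22}

best, score, evaluated = line_search([16, 32, 64, 128, 256], fake_loss.get)
print("evaluated:", evaluated, "best:", best)

# Trial counts from the text: early-stopped nested search vs. full grid.
nested_trials = 3 * 2 * 5
grid_trials = 5 * 3 * 5
print(nested_trials, "trials instead of", grid_trials)
```

With these made-up losses the search stops after 16, 32 and 64, because 64 is the first value that fails to improve on the best score so far. The values 128 and 256 are never tried, which is where the savings over a full grid come from.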

Future Developments

Eventually, I hope to have the time to implement more hyperparameter search methods. These could include the ever-popular TPE, but also other, more unconventional local search approaches.
