Extensions for catboost models

Catboost-extensions


This library provides an easy-to-use interface for hyperparameter tuning of CatBoost models using Optuna. The OptunaTuneCV class simplifies the process of defining parameter spaces, configuring trials, and running cross-validation with CatBoost.

Installation

To install the library, use pip:

pip install catboost-extensions
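
To confirm that the package landed in the active environment, the Python standard library can report the installed version. This is only a minimal sketch: the helper name installed_version is hypothetical and not part of this library.

```python
from importlib import metadata


def installed_version(dist_name="catboost-extensions"):
    """Return the installed version string, or None if the distribution is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Prints a version string if the package is installed, otherwise None.
print(installed_version())
```

A result of None usually means pip installed into a different interpreter than the one running the script.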

Quick Start Guide

OptunaTuneCV

Here is an example of how to use the library to tune a CatBoost model using Optuna:

1. Import necessary libraries

from pprint import pprint

import pandas as pd

from catboost_extensions.optuna import (
    OptunaTuneCV,
    CatboostParamSpace,
)
from catboost import CatBoostRegressor
from sklearn.datasets import fetch_california_housing
import optuna

2. Load and prepare your data

# Load dataset
data = fetch_california_housing()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = data.target

3. Define your CatBoost model

model = CatBoostRegressor(verbose=False, task_type='CPU')

4. Define the parameter space

The CatboostParamSpace class allows you to define a parameter space for your CatBoost model. You can remove parameters that you don't want to tune using the del_params method.

param_space = CatboostParamSpace(params_preset='general', task_type='CPU')
param_space.del_params(['depth', 'l2_leaf_reg'])
pprint(param_space.get_params_space())

Out:

{'bootstrap_type': CategoricalDistribution(choices=('Bayesian', 'MVS', 'Bernoulli', 'No')),
 'grow_policy': CategoricalDistribution(choices=('SymmetricTree', 'Depthwise', 'Lossguide')),
 'iterations': IntDistribution(high=5000, log=False, low=100, step=1),
 'learning_rate': FloatDistribution(high=0.1, log=True, low=0.001, step=None),
 'max_bin': IntDistribution(high=512, log=False, low=8, step=1),
 'random_strength': FloatDistribution(high=10.0, log=True, low=0.01, step=None),
 'rsm': FloatDistribution(high=1.0, log=False, low=0.01, step=None),
 'score_function': CategoricalDistribution(choices=('Cosine', 'L2'))}
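
In the output above, learning_rate and random_strength carry log=True. In Optuna this means values are drawn uniformly in log space rather than on a linear scale, so each order of magnitude gets roughly equal attention. The standard-library sketch below illustrates the effect for the learning_rate bounds; sample_log_uniform is a hypothetical helper written for this illustration, not Optuna's actual sampler.

```python
import math
import random


def sample_log_uniform(low, high, rng):
    """Draw one value whose logarithm is uniform between log(low) and log(high)."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
low, high = 0.001, 0.1  # the learning_rate bounds shown above
samples = [sample_log_uniform(low, high, rng) for _ in range(10_000)]

# 0.01 is the geometric midpoint of [0.001, 0.1], so about half of the
# samples should fall below it. A linear scale would put only ~9% there.
share_below_midpoint = sum(s < 0.01 for s in samples) / len(samples)
print(f"share of samples below 0.01: {share_below_midpoint:.2f}")
```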

You can also change a parameter's default values by assigning to the corresponding attribute:

param_space.iterations = (1000, 2000)

5. Set up the OptunaTuneCV objective

The OptunaTuneCV class defines an objective function for Optuna. It takes the CatBoost model, the parameter space, the dataset, and further options such as the trial timeout and the scoring metric.

objective = OptunaTuneCV(model, param_space, X, y, trial_timeout=10, scoring='r2')

6. Create an Optuna study and optimize

You can choose an Optuna sampler (e.g., TPESampler) and then create a study to optimize the objective function.

sampler = optuna.samplers.TPESampler(seed=20, multivariate=True)
study = optuna.create_study(direction='maximize', sampler=sampler)
study.optimize(objective, n_trials=10)
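
The study uses direction='maximize' because the objective was configured with scoring='r2', and a higher R² means a better fit. Assuming 'r2' refers to the usual coefficient of determination (as in scikit-learn's scorer of the same name), the metric can be worked out by hand as below. The function r2 is a hypothetical pure-Python helper for illustration and is not taken from this library.

```python
def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("R^2 is undefined when y_true is constant")
    return 1.0 - ss_res / ss_tot


y = [3.0, 5.0, 7.0]
print(r2(y, [3.0, 5.0, 7.0]))  # 1.0  (perfect predictions)
print(r2(y, [5.0, 5.0, 5.0]))  # 0.0  (always predicting the mean)
print(r2(y, [7.0, 5.0, 3.0]))  # -3.0 (worse than predicting the mean)
```

For an error metric where lower is better, the study direction would need to be 'minimize' instead.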

7. View the results

After the study completes, you can analyze the results to see the best hyperparameters found during the optimization.

print("Best trial:")
trial = study.best_trial
print(f"  Value: {trial.value}")
print("  Params:")
for key, value in trial.params.items():
    print(f"    {key}: {value}")
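
A common next step is to refit a final model with the best parameters. The sketch below merges the tuned values with the fixed constructor arguments from step 3. It assumes that the keys in trial.params are valid CatBoost constructor arguments, which this README does not state explicitly; final_model_params is a hypothetical helper, and the example values are placeholders, not results from a real study.

```python
def final_model_params(best_params, fixed_params):
    """Merge tuned and fixed parameters, refusing silent overrides."""
    overlap = set(best_params) & set(fixed_params)
    if overlap:
        raise ValueError(f"parameters set twice: {sorted(overlap)}")
    return {**fixed_params, **best_params}


# Placeholder values for illustration only (not output of an actual study).
example_best = {"learning_rate": 0.05, "iterations": 1500}
fixed = {"verbose": False, "task_type": "CPU"}

params = final_model_params(example_best, fixed)
print(params)

# With a real study, under the assumption stated above:
#   final_model = CatBoostRegressor(**final_model_params(study.best_trial.params, fixed))
#   final_model.fit(X, y)
```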

Contributing

If you want to contribute to this library, please open an issue or submit a pull request on GitHub.

License

This project is licensed under the MIT License.
