
Extensions for CatBoost models


Catboost-extensions


This library provides an easy-to-use interface for hyperparameter tuning of CatBoost models using Optuna. The OptunaTuneCV class simplifies the process of defining parameter spaces, configuring trials, and running cross-validation with CatBoost.

Installation

To install the library, use pip:

pip install catboost-extensions

Quick Start Guide

OptunaTuneCV

Here is an example of how to use the library to tune a CatBoost model using Optuna:

1. Import necessary libraries

from pprint import pprint

import pandas as pd

from catboost_extensions.optuna import (
    OptunaTuneCV, 
    CatboostParamSpace,
)
from catboost import CatBoostRegressor
from sklearn.datasets import fetch_california_housing
import optuna

2. Load and prepare your data

# Load dataset
data = fetch_california_housing()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = data.target

3. Define your CatBoost model

model = CatBoostRegressor(verbose=False, task_type='CPU')

4. Define the parameter space

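The CatboostParamSpace class defines the search space for your CatBoost model, starting from a preset (here 'general'). Parameters you don't want to tune can be removed with the del_params method; in the example below, depth and l2_leaf_reg are removed, so they do not appear in the printed space.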

param_space = CatboostParamSpace(params_preset='general', task_type='CPU')
param_space.del_params(['depth', 'l2_leaf_reg'])
pprint(param_space.get_params_space())

Out:

{'bootstrap_type': CategoricalDistribution(choices=('Bayesian', 'MVS', 'Bernoulli', 'No')),
 'grow_policy': CategoricalDistribution(choices=('SymmetricTree', 'Depthwise', 'Lossguide')),
 'iterations': IntDistribution(high=5000, log=False, low=100, step=1),
 'learning_rate': FloatDistribution(high=0.1, log=True, low=0.001, step=None),
 'max_bin': IntDistribution(high=512, log=False, low=8, step=1),
 'random_strength': FloatDistribution(high=10.0, log=True, low=0.01, step=None),
 'rsm': FloatDistribution(high=1.0, log=False, low=0.01, step=None),
 'score_function': CategoricalDistribution(choices=('Cosine', 'L2'))}
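The printed distributions use the fields of Optuna's IntDistribution, FloatDistribution and CategoricalDistribution. As a rough illustration of what those fields mean (not of how Optuna's samplers actually pick values), the sketch below draws one value of each kind uniformly at random: step restricts integers to a grid, and log=True makes the draw uniform in log space, so for learning_rate the interval 0.001 to 0.01 is as likely as 0.01 to 0.1. The helper names sample_int, sample_float and sample_categorical are hypothetical.

```python
import math
import random
from typing import Sequence


def sample_int(rng: random.Random, low: int, high: int, step: int = 1) -> int:
    # Uniform over the grid low, low + step, ..., never exceeding high.
    n_steps = (high - low) // step
    return low + step * rng.randint(0, n_steps)


def sample_float(rng: random.Random, low: float, high: float, log: bool = False) -> float:
    if log:
        # Log-uniform: draw uniformly between log(low) and log(high), then exponentiate.
        value = math.exp(rng.uniform(math.log(low), math.log(high)))
    else:
        value = rng.uniform(low, high)
    # Guard against floating-point round-off pushing the value just outside the range.
    return min(max(value, low), high)


def sample_categorical(rng: random.Random, choices: Sequence[str]) -> str:
    return rng.choice(list(choices))


rng = random.Random(20)
params = {
    # IntDistribution(low=8, high=512, step=1)
    "max_bin": sample_int(rng, 8, 512),
    # FloatDistribution(low=0.001, high=0.1, log=True)
    "learning_rate": sample_float(rng, 0.001, 0.1, log=True),
    # FloatDistribution(low=0.01, high=1.0, log=False)
    "rsm": sample_float(rng, 0.01, 1.0),
    # CategoricalDistribution(choices=('Cosine', 'L2'))
    "score_function": sample_categorical(rng, ("Cosine", "L2")),
}
print(params)
```

Uniform draws like these resemble only the first few trials of a model-based sampler such as TPE, which afterwards concentrates its suggestions in regions that scored well.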

You can also change the default search range of a parameter. For example, to restrict iterations to between 1000 and 2000 (the preset range is 100 to 5000):

param_space.iterations = (1000, 2000)
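One way to picture del_params and a range override is a plain dictionary of distributions. The sketch below is a hypothetical stand-in, not the library's implementation: ToyParamSpace, IntRange, FloatRange and set_range are illustrative names, the depth and l2_leaf_reg ranges are made up, and it assumes a (low, high) tuple replaces only the bounds while keeping step and log. The README expresses the override as attribute assignment; the sketch uses an explicit method to keep it short.

```python
from dataclasses import dataclass, replace
from typing import Dict, Sequence, Tuple, Union


@dataclass(frozen=True)
class IntRange:
    low: int
    high: int
    step: int = 1


@dataclass(frozen=True)
class FloatRange:
    low: float
    high: float
    log: bool = False


Range = Union[IntRange, FloatRange]


class ToyParamSpace:
    """Hypothetical dict-backed parameter space; not catboost_extensions' class."""

    def __init__(self, space: Dict[str, Range]) -> None:
        self._space = dict(space)

    def del_params(self, names: Sequence[str]) -> None:
        # Removed parameters are simply not tuned; unknown names are ignored here.
        for name in names:
            self._space.pop(name, None)

    def set_range(self, name: str, bounds: Tuple[float, float]) -> None:
        # Replace only low/high; step and log keep their previous values.
        low, high = bounds
        if low > high:
            raise ValueError(f"{name}: low {low} exceeds high {high}")
        self._space[name] = replace(self._space[name], low=low, high=high)

    def get_params_space(self) -> Dict[str, Range]:
        return dict(self._space)


space = ToyParamSpace({
    "depth": IntRange(4, 10),  # illustrative range only
    "l2_leaf_reg": FloatRange(1.0, 10.0, log=True),  # illustrative range only
    "iterations": IntRange(100, 5000),
    "learning_rate": FloatRange(0.001, 0.1, log=True),
})
space.del_params(["depth", "l2_leaf_reg"])
space.set_range("iterations", (1000, 2000))
print(space.get_params_space())
```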

5. Set up the OptunaTuneCV objective

The OptunaTuneCV class builds the objective function that Optuna optimizes. You pass it the CatBoost model, the parameter space and the dataset, plus options such as the trial timeout (trial_timeout) and the scoring metric (scoring).

objective = OptunaTuneCV(model, param_space, X, y, trial_timeout=10, scoring='r2')
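Conceptually, a cross-validated objective scores one set of parameters by fitting on k-1 folds, scoring the held-out fold, and averaging over the folds. The sketch below shows that loop for scoring='r2' in plain Python. A one-feature least-squares line stands in for the CatBoost model, and the fold count, shuffling and averaging are assumptions, since the README does not say how OptunaTuneCV splits the data or aggregates fold scores.

```python
import random
from statistics import fmean
from typing import Callable, List, Sequence

Predictor = Callable[[float], float]


def r2_score(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    # R^2 = 1 - SS_res / SS_tot: 1.0 is a perfect fit, 0.0 matches always predicting the mean.
    mean = fmean(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("R^2 is undefined when y_true is constant")
    return 1.0 - ss_res / ss_tot


def kfold_indices(n: int, k: int, seed: int = 0) -> List[List[int]]:
    # Shuffle the row indices once, then deal them round-robin into k folds.
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]


def fit_line(xs: Sequence[float], ys: Sequence[float]) -> Predictor:
    # Least-squares fit of y = a + b * x; a stand-in for fitting CatBoost.
    mx, my = fmean(xs), fmean(ys)
    b = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / sum((x - mx) ** 2 for x in xs)
    a = my - b * mx
    return lambda x: a + b * x


def cv_score(xs: Sequence[float], ys: Sequence[float], k: int = 5, seed: int = 0) -> float:
    # Train on k-1 folds, score R^2 on the held-out fold, and average the k scores.
    scores = []
    for held_out in kfold_indices(len(xs), k, seed):
        held = set(held_out)
        train = [i for i in range(len(xs)) if i not in held]
        model = fit_line([xs[i] for i in train], [ys[i] for i in train])
        y_true = [ys[i] for i in held_out]
        y_pred = [model(xs[i]) for i in held_out]
        scores.append(r2_score(y_true, y_pred))
    return fmean(scores)


xs = [float(i) for i in range(40)]
ys = [2.0 * x + 1.0 + (1.0 if i % 2 else -1.0) for i, x in enumerate(xs)]
print(round(cv_score(xs, ys, k=5), 3))
```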

6. Create an Optuna study and optimize

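Choose an Optuna sampler (for example, TPESampler), then create a study and optimize the objective. Because the scoring metric here is R², where higher is better, the study is created with direction='maximize'. Passing seed to the sampler makes its suggestions reproducible.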

sampler = optuna.samplers.TPESampler(seed=20, multivariate=True)
study = optuna.create_study(direction='maximize', sampler=sampler)
study.optimize(objective, n_trials=10)
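study.optimize calls the objective n_trials times, records each trial's parameters and value, and exposes the best one as study.best_trial (with study.trials and study.best_params also available). The sketch below mimics that bookkeeping with plain random search in place of TPE; TPE instead models which regions scored well and samples there more often. ToyTrial, random_search and toy_objective are hypothetical, and the objective takes a params dict rather than an Optuna trial object.

```python
import math
import random
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class ToyTrial:
    number: int
    params: Dict[str, float]
    value: float


def random_search(
    objective: Callable[[Dict[str, float]], float],
    n_trials: int,
    seed: int,
    direction: str = "maximize",
) -> Tuple[ToyTrial, List[ToyTrial]]:
    if direction not in ("maximize", "minimize"):
        raise ValueError("direction must be 'maximize' or 'minimize'")
    rng = random.Random(seed)  # a fixed seed makes the sequence of trials reproducible
    trials: List[ToyTrial] = []
    for number in range(n_trials):
        # Log-uniform draw over learning_rate's range from the parameter space above.
        lr = math.exp(rng.uniform(math.log(0.001), math.log(0.1)))
        params = {"learning_rate": lr}
        trials.append(ToyTrial(number, params, objective(params)))
    pick = max if direction == "maximize" else min
    best = pick(trials, key=lambda t: t.value)
    return best, trials


def toy_objective(params: Dict[str, float]) -> float:
    # Made-up score that peaks at learning_rate = 0.03 (a stand-in for a CV R^2).
    return 1.0 - (math.log10(params["learning_rate"]) - math.log10(0.03)) ** 2


best, trials = random_search(toy_objective, n_trials=10, seed=20)
print("Best trial:", best.number, best.value, best.params)
```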

7. View the results

After the study completes, you can analyze the results to see the best hyperparameters found during the optimization.

print("Best trial:")
trial = study.best_trial
print(f"  Value: {trial.value}")
print("  Params:")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

Contributing

If you want to contribute to this library, please open an issue or submit a pull request on GitHub.

License

This project is licensed under the MIT License.


Download files


Source Distribution

catboost_extensions-2.8.tar.gz (26.3 kB)

Uploaded Source

Built Distribution


catboost_extensions-2.8-py3-none-any.whl (26.2 kB)

Uploaded Python 3

File details

Details for the file catboost_extensions-2.8.tar.gz.

File metadata

  • Download URL: catboost_extensions-2.8.tar.gz
  • Upload date:
  • Size: 26.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.11.7

File hashes

Hashes for catboost_extensions-2.8.tar.gz
Algorithm Hash digest
SHA256 8b353adb1c2fb1cdf0c60ac749dca2099cd9cec0de73134f699ec24c7bf52d87
MD5 daab2559d088ed201b3a7c014ad0cd99
BLAKE2b-256 5954b28c22a01996a26695bda9573316c90d665457b7b233904439e0a814c704


File details

Details for the file catboost_extensions-2.8-py3-none-any.whl.

File metadata

File hashes

Hashes for catboost_extensions-2.8-py3-none-any.whl
Algorithm Hash digest
SHA256 9c6212230f952be4fdfebf3e3c533ba0536250ee012df510c204851028a118e4
MD5 2517726d20b6779dc48c8818f3305a34
BLAKE2b-256 6fe681c93b37d5854e78fa4d85c1bf667381f843fb0292c5c2f4295b348e4353

