Extensions for catboost models

Catboost-extensions


This library provides an easy-to-use interface for hyperparameter tuning of CatBoost models using Optuna. The OptunaTuneCV class simplifies the process of defining parameter spaces, configuring trials, and running cross-validation with CatBoost.

Installation

To install the library, use pip:

pip install catboost-extensions

Quick Start Guide

OptunaTuneCV

Here is an example of how to use the library to tune a CatBoost model using Optuna:

1. Import necessary libraries

from pprint import pprint

import pandas as pd

from catboost_extensions.optuna import (
    OptunaTuneCV,
    CatboostParamSpace,
)
from catboost import CatBoostRegressor
from sklearn.datasets import fetch_california_housing
import optuna

2. Load and prepare your data

# Load dataset
data = fetch_california_housing()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = data.target

3. Define your CatBoost model

model = CatBoostRegressor(verbose=False, task_type='CPU')

4. Define the parameter space

The CatboostParamSpace class allows you to define a parameter space for your CatBoost model. You can remove parameters that you don't want to tune using the del_params method.

param_space = CatboostParamSpace(params_preset='general', task_type='CPU')
param_space.del_params(['depth', 'l2_leaf_reg'])
pprint(param_space.get_params_space())

Out:

{'bootstrap_type': CategoricalDistribution(choices=('Bayesian', 'MVS', 'Bernoulli', 'No')),
 'grow_policy': CategoricalDistribution(choices=('SymmetricTree', 'Depthwise', 'Lossguide')),
 'iterations': IntDistribution(high=5000, log=False, low=100, step=1),
 'learning_rate': FloatDistribution(high=0.1, log=True, low=0.001, step=None),
 'max_bin': IntDistribution(high=512, log=False, low=8, step=1),
 'random_strength': FloatDistribution(high=10.0, log=True, low=0.01, step=None),
 'rsm': FloatDistribution(high=1.0, log=False, low=0.01, step=None),
 'score_function': CategoricalDistribution(choices=('Cosine', 'L2'))}
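In the output above, distributions with log=True (learning_rate, random_strength) are sampled uniformly on a logarithmic scale, so each order of magnitude in the range is equally likely. The following is a minimal standard-library sketch of that idea, not Optuna's own code; the helper name sample_log_uniform is hypothetical.

```python
import math
import random


def sample_log_uniform(low, high, rng):
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)

# Same bounds as the learning_rate entry above: low=0.001, high=0.1.
samples = [sample_log_uniform(0.001, 0.1, rng) for _ in range(10_000)]

# On a log scale the midpoint of the range is the geometric mean, 0.01,
# so roughly half of the draws should fall below it.
share_below = sum(s < 0.01 for s in samples) / len(samples)
print(f"share of samples below 0.01: {share_below:.2f}")
```

With a plain uniform distribution on the same bounds, only about 9% of the draws would fall below 0.01, which is why small learning rates are usually searched on a log scale.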

You can also override the default range of a parameter by assigning a new one:

param_space.iterations = (1000, 2000)

5. Set up the OptunaTuneCV objective

The OptunaTuneCV class helps to define an objective function for Optuna. You can specify the CatBoost model, the parameter space, the dataset, and other options such as the trial timeout and the scoring metric.

objective = OptunaTuneCV(model, param_space, X, y, trial_timeout=10, scoring='r2')
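To make scoring='r2' and the cross-validation step more concrete, here is a rough standard-library sketch of what a cross-validated R² score is. It is an illustration only: the source does not show how OptunaTuneCV is implemented, and the helpers r2_score, fit_line, and cross_val_r2 are hypothetical stand-ins (a one-feature least-squares line replaces the CatBoost model).

```python
from statistics import mean


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_bar = mean(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - y_bar) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def fit_line(xs, ys):
    """Ordinary least squares for y = a * x + b (stand-in for a real model)."""
    x_bar, y_bar = mean(xs), mean(ys)
    a = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, ys)) / sum(
        (x - x_bar) ** 2 for x in xs
    )
    return a, y_bar - a * x_bar


def cross_val_r2(xs, ys, n_splits=3):
    """Average R^2 over folds; fold k holds out indices k, k + n_splits, ..."""
    scores = []
    for k in range(n_splits):
        test = [i for i in range(len(xs)) if i % n_splits == k]
        train = [i for i in range(len(xs)) if i % n_splits != k]
        a, b = fit_line([xs[i] for i in train], [ys[i] for i in train])
        preds = [a * xs[i] + b for i in test]
        scores.append(r2_score([ys[i] for i in test], preds))
    return mean(scores)


# Toy data: a line with a small alternating perturbation.
xs = list(range(12))
ys = [2 * x + 1 + (0.5 if x % 2 else -0.5) for x in xs]

score = cross_val_r2(xs, ys)
print(f"mean R^2 across folds: {score:.3f}")
```

A perfect prediction gives R² = 1 and always predicting the mean gives R² = 0, so higher is better and the study that consumes this score should maximize it.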

6. Create an Optuna study and optimize

You can choose an Optuna sampler (e.g., TPESampler) and then create a study to optimize the objective function.

sampler = optuna.samplers.TPESampler(seed=20, multivariate=True)
study = optuna.create_study(direction='maximize', sampler=sampler)
study.optimize(objective, n_trials=10)
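Conceptually, the study repeatedly proposes parameters, evaluates the objective, and keeps the best result for the chosen direction. The sketch below shows that loop with plain random search swapped in for TPE, which additionally uses earlier trials to guide later proposals. It uses only the standard library; run_random_search and toy_objective are hypothetical names and not part of Optuna or this package.

```python
import math
import random


def run_random_search(objective, space, n_trials, seed):
    """Evaluate `objective` on random draws from `space`; keep the maximum."""
    rng = random.Random(seed)
    best_params, best_value = None, -math.inf
    for _ in range(n_trials):
        # Propose one value per parameter, uniformly within its (low, high) range.
        params = {name: rng.uniform(low, high) for name, (low, high) in space.items()}
        value = objective(params)
        if value > best_value:  # direction='maximize'
            best_params, best_value = params, value
    return best_params, best_value


def toy_objective(params):
    """Toy score that peaks at 1.0 when x == 0.3."""
    return 1.0 - (params["x"] - 0.3) ** 2


best_params, best_value = run_random_search(
    toy_objective, {"x": (0.0, 1.0)}, n_trials=200, seed=20
)
print(f"best value: {best_value:.4f} at x = {best_params['x']:.3f}")
```

Fixing the seed, as the TPESampler call above does with seed=20, makes the sequence of proposals reproducible between runs.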

7. View the results

After the study completes, you can analyze the results to see the best hyperparameters found during the optimization.

print("Best trial:")
trial = study.best_trial
print(f"  Value: {trial.value}")
print("  Params:")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

Contributing

If you want to contribute to this library, please open an issue or submit a pull request on GitHub.

License

This project is licensed under the MIT License.
