Extensions for catboost models

Catboost-extensions


This library provides an easy-to-use interface for hyperparameter tuning of CatBoost models using Optuna. The CatboostParamSpace class defines the parameter space, and the OptunaTuneCV class wraps the model, the data, and the cross-validation settings into an objective that Optuna can optimize.

Installation

To install the library, use pip:

pip install catboost-extensions

Quick Start Guide

OptunaTuneCV

Here is an example of how to use the library to tune a CatBoost model using Optuna:

1. Import necessary libraries

from pprint import pprint

import pandas as pd

from catboost_extensions.optuna import (
    OptunaTuneCV, 
    CatboostParamSpace,
)
from catboost import CatBoostRegressor
from sklearn.datasets import fetch_california_housing
import optuna

2. Load and prepare your data

# Load dataset
data = fetch_california_housing()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = data.target

3. Define your CatBoost model

model = CatBoostRegressor(verbose=False, task_type='CPU')

4. Define the parameter space

The CatboostParamSpace class allows you to define a parameter space for your CatBoost model. You can remove parameters that you don't want to tune using the del_params method.

param_space = CatboostParamSpace(params_preset='general', task_type='CPU')
param_space.del_params(['depth', 'l2_leaf_reg'])
pprint(param_space.get_params_space())

Out:

{'bootstrap_type': CategoricalDistribution(choices=('Bayesian', 'MVS', 'Bernoulli', 'No')),
 'grow_policy': CategoricalDistribution(choices=('SymmetricTree', 'Depthwise', 'Lossguide')),
 'iterations': IntDistribution(high=5000, log=False, low=100, step=1),
 'learning_rate': FloatDistribution(high=0.1, log=True, low=0.001, step=None),
 'max_bin': IntDistribution(high=512, log=False, low=8, step=1),
 'random_strength': FloatDistribution(high=10.0, log=True, low=0.01, step=None),
 'rsm': FloatDistribution(high=1.0, log=False, low=0.01, step=None),
 'score_function': CategoricalDistribution(choices=('Cosine', 'L2'))}
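
In the output above, learning_rate and random_strength are marked log=True. In Optuna this means the value is sampled on a logarithmic scale rather than on the raw scale, so small values get as much attention as large ones. The sketch below illustrates the effect for the learning_rate bounds using only the standard library. The helper name sample_log_uniform is hypothetical, and this is not how Optuna's samplers work internally (TPE is adaptive); it only shows what the scale implies.

```python
import math
import random


def sample_log_uniform(rng, low, high):
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
low, high = 0.001, 0.1  # learning_rate bounds from the printed space

log_draws = [sample_log_uniform(rng, low, high) for _ in range(10_000)]
linear_draws = [rng.uniform(low, high) for _ in range(10_000)]

# Fraction of draws that fall in the lowest decade, [0.001, 0.01).
log_share = sum(v < 0.01 for v in log_draws) / len(log_draws)
linear_share = sum(v < 0.01 for v in linear_draws) / len(linear_draws)

print(f"log-uniform share below 0.01: {log_share:.2f}")
print(f"uniform share below 0.01:     {linear_share:.2f}")
```

On a log scale roughly half of the draws land below 0.01, while plain uniform sampling puts only about 9% there, which is why log scaling is a common choice for learning rates.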

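You can also override the default search range of a parameter by assigning to the attribute of the same name, for example to narrow the range for iterations:

param_space.iterations = (1000, 2000)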

5. Set up the OptunaTuneCV objective

The OptunaTuneCV class helps to define an objective function for Optuna. You can specify the CatBoost model, the parameter space, the dataset, and other options such as the trial timeout and the scoring metric.

objective = OptunaTuneCV(model, param_space, X, y, trial_timeout=10, scoring='r2')

6. Create an Optuna study and optimize

You can choose an Optuna sampler (e.g., TPESampler) and then create a study to optimize the objective function.

sampler = optuna.samplers.TPESampler(seed=20, multivariate=True)
study = optuna.create_study(direction='maximize', sampler=sampler)
study.optimize(objective, n_trials=10)
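
The study uses direction='maximize' because the objective was created with scoring='r2'. Assuming 'r2' denotes the usual coefficient of determination, as in scikit-learn's scoring strings, higher is better: 1.0 is a perfect fit, 0.0 matches always predicting the mean, and negative values are worse than that. A small standard-library sketch of the definition R² = 1 - SS_res / SS_tot follows; the function name r2 and the sample numbers are illustrative only.

```python
def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_y = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_y) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


y_true = [1.0, 2.0, 3.0, 4.0]

perfect = r2(y_true, y_true)              # predictions equal the targets
baseline = r2(y_true, [2.5] * 4)          # always predict the mean
worse = r2(y_true, [4.0, 3.0, 2.0, 1.0])  # predictions in reversed order

print(perfect, baseline, worse)  # prints: 1.0 0.0 -3.0
```

For an error metric where lower is better, the study would need direction='minimize' instead.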

7. View the results

After the study completes, you can analyze the results to see the best hyperparameters found during the optimization.

print("Best trial:")
trial = study.best_trial
print(f"  Value: {trial.value}")
print("  Params:")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

Contributing

If you want to contribute to this library, please open an issue or submit a pull request on GitHub.

License

This project is licensed under the MIT License.

