Extensions for CatBoost models
Catboost-extensions
This library provides an easy-to-use interface for tuning CatBoost hyperparameters with Optuna. The OptunaTuneCV class simplifies defining parameter spaces, configuring trials, and running cross-validation with CatBoost.
Installation
To install the library, use pip:
pip install catboost-extensions
Quick Start Guide
OptunaTuneCV
Here is an example of how to use the library to tune a CatBoost model using Optuna:
1. Import the necessary libraries
from pprint import pprint
import pandas as pd
from catboost_extensions.optuna import (
    OptunaTuneCV,
    CatboostParamSpace,
)
from catboost import CatBoostRegressor
from sklearn.datasets import fetch_california_housing
import optuna
2. Load and prepare your data
# Load dataset
data = fetch_california_housing()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = data.target
3. Define your CatBoost model
model = CatBoostRegressor(verbose=False, task_type='CPU')
4. Define the parameter space
The CatboostParamSpace class lets you define a parameter space for your CatBoost model. Use the del_params method to remove parameters you don't want to tune.
param_space = CatboostParamSpace(params_preset='general', task_type='CPU')
param_space.del_params(['depth', 'l2_leaf_reg'])
pprint(param_space.get_params_space())
Out:
{'bootstrap_type': CategoricalDistribution(choices=('Bayesian', 'MVS', 'Bernoulli', 'No')),
'grow_policy': CategoricalDistribution(choices=('SymmetricTree', 'Depthwise', 'Lossguide')),
'iterations': IntDistribution(high=5000, log=False, low=100, step=1),
'learning_rate': FloatDistribution(high=0.1, log=True, low=0.001, step=None),
'max_bin': IntDistribution(high=512, log=False, low=8, step=1),
'random_strength': FloatDistribution(high=10.0, log=True, low=0.01, step=None),
'rsm': FloatDistribution(high=1.0, log=False, low=0.01, step=None),
'score_function': CategoricalDistribution(choices=('Cosine', 'L2'))}
You can also override the default search range of a parameter. For example, to tune iterations between 1000 and 2000:
param_space.iterations = (1000, 2000)
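To make the idea of a parameter space concrete, here is a minimal pure-Python sketch of what a space like the one printed above represents: integer ranges, log-scaled float ranges and categorical choices, plus deleting parameters and overriding a range. It is a hypothetical stand-in written for illustration only (the class `ToyParamSpace` and its `set_range` and `sample` methods are assumptions, and the `depth` range is made up); it is not how CatboostParamSpace is implemented, which returns Optuna distribution objects instead.

```python
import math
import random


class ToyParamSpace:
    """Hypothetical, simplified parameter space; NOT CatboostParamSpace itself."""

    def __init__(self):
        # name -> (kind, spec). Ranges for iterations, learning_rate and
        # grow_policy mirror the printed output; the depth range is assumed.
        self.space = {
            "iterations": ("int", (100, 5000)),
            "learning_rate": ("float_log", (0.001, 0.1)),
            "grow_policy": ("categorical", ("SymmetricTree", "Depthwise", "Lossguide")),
            "depth": ("int", (4, 10)),
        }

    def del_params(self, names):
        """Drop parameters that should not be tuned (unknown names are ignored)."""
        for name in names:
            self.space.pop(name, None)

    def set_range(self, name, low, high):
        """Replace the numeric range of a parameter, e.g. iterations -> (1000, 2000)."""
        kind, _ = self.space[name]
        if kind == "categorical":
            raise ValueError(f"{name} is categorical and has no numeric range")
        if low > high:
            raise ValueError("low must not exceed high")
        self.space[name] = (kind, (low, high))

    def sample(self, rng):
        """Draw one configuration: uniform ints, log-uniform floats, uniform choices."""
        params = {}
        for name, (kind, spec) in self.space.items():
            if kind == "int":
                params[name] = rng.randint(*spec)  # inclusive on both ends
            elif kind == "float_log":
                low, high = spec
                # Sample uniformly in log space so small values are explored as often as large ones
                params[name] = math.exp(rng.uniform(math.log(low), math.log(high)))
            else:
                params[name] = rng.choice(spec)
        return params


space = ToyParamSpace()
space.del_params(["depth"])
space.set_range("iterations", 1000, 2000)  # analogous to param_space.iterations = (1000, 2000)
params = space.sample(random.Random(20))
print(params)
```

Log-uniform sampling is why the printed learning_rate distribution has log=True: over 0.001 to 0.1, each decade gets equal probability instead of almost all draws landing near 0.1.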
5. Set up the OptunaTuneCV objective
The OptunaTuneCV class builds the objective function that Optuna optimizes. You pass it the CatBoost model, the parameter space, and the dataset, plus options such as the per-trial timeout (trial_timeout) and the scoring metric (scoring).
objective = OptunaTuneCV(model, param_space, X, y, trial_timeout=10, scoring='r2')
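Conceptually, a cross-validation objective samples hyperparameters for a trial, runs k-fold cross-validation, and returns the mean score for Optuna to optimize. The pure-Python sketch below shows only the k-fold and R² part of that loop, with a deliberately trivial one-feature linear model; `kfold_indices`, `r2_score`, `cv_mean_score`, `fit_line` and `predict_line` are hypothetical helpers, and the sketch makes no claim about how OptunaTuneCV splits folds or enforces trial_timeout.

```python
from statistics import fmean


def kfold_indices(n, k):
    """Yield (train_idx, test_idx) for k contiguous folds (no shuffling)."""
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n))
        yield train, test
        start += size


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = fmean(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def cv_mean_score(X, y, fit, predict, k=5):
    """Fit on k-1 folds, score R^2 on the held-out fold, and average over folds."""
    scores = []
    for train, test in kfold_indices(len(y), k):
        model = fit([X[i] for i in train], [y[i] for i in train])
        preds = predict(model, [X[i] for i in test])
        scores.append(r2_score([y[i] for i in test], preds))
    return fmean(scores)


# Toy model: ordinary least squares on a single feature.
def fit_line(xs, ys):
    mx, my = fmean(xs), fmean(ys)
    slope = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / sum((x - mx) ** 2 for x in xs)
    return slope, my - slope * mx


def predict_line(model, xs):
    slope, intercept = model
    return [slope * x + intercept for x in xs]


X_toy = [float(i) for i in range(20)]
y_toy = [2.0 * x + 1.0 for x in X_toy]
score = cv_mean_score(X_toy, y_toy, fit_line, predict_line, k=5)
print(score)  # noiseless linear data is fit exactly, so every fold scores R^2 = 1.0
```

Contiguous, unshuffled folds keep the example deterministic; on real data such as California housing, shuffling before splitting is usually advisable.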
6. Create an Optuna study and optimize
Choose an Optuna sampler (e.g., TPESampler), then create a study and optimize the objective function. Because R² is a higher-is-better metric, the study uses direction='maximize'.
sampler = optuna.samplers.TPESampler(seed=20, multivariate=True)
study = optuna.create_study(direction='maximize', sampler=sampler)
study.optimize(objective, n_trials=10)
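Roughly, study.optimize repeats three steps per trial: get a parameter set from the sampler, evaluate the objective, and record the result, with the best trial tracked according to the study direction. The sketch below reproduces that loop in plain Python, with seeded random search standing in for TPE (TPE instead builds a model of which regions scored well and samples there more often). `Trial`, `run_study`, `toy_objective` and `sample_lr` are hypothetical names for illustration, not Optuna's API.

```python
import random
from dataclasses import dataclass


@dataclass
class Trial:
    number: int
    params: dict
    value: float


def run_study(objective, sample_params, n_trials, direction="maximize", seed=None):
    """Evaluate n_trials sampled parameter sets; return (best_trial, all_trials)."""
    if direction not in ("maximize", "minimize"):
        raise ValueError("direction must be 'maximize' or 'minimize'")
    rng = random.Random(seed)  # seeded so repeated runs draw the same parameters
    trials = []
    for number in range(n_trials):
        params = sample_params(rng)
        trials.append(Trial(number, params, objective(params)))
    pick = max if direction == "maximize" else min
    best = pick(trials, key=lambda t: t.value)
    return best, trials


# Toy objective that peaks at learning_rate = 0.05 (purely illustrative).
def toy_objective(params):
    return -(params["learning_rate"] - 0.05) ** 2


def sample_lr(rng):
    return {"learning_rate": rng.uniform(0.001, 0.1)}


best, trials = run_study(toy_objective, sample_lr, n_trials=10, seed=20)
print(f"Best trial #{best.number}: value={best.value:.6f}, params={best.params}")
```

Fixing the seed plays the same role as seed=20 in the TPESampler call: rerunning the study reproduces the same sequence of trials.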
7. View the results
After the study completes, you can analyze the results to see the best hyperparameters found during the optimization.
print("Best trial:")
trial = study.best_trial
print(f"  Value: {trial.value}")
print("  Params:")
for key, value in trial.params.items():
    print(f"    {key}: {value}")
Contributing
If you want to contribute to this library, please open an issue or submit a pull request on GitHub.
License
This project is licensed under the MIT License.
File details
Details for the file catboost_extensions-2.1.tar.gz.
File metadata
- Download URL: catboost_extensions-2.1.tar.gz
- Upload date:
- Size: 14.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 54fd5a94c251bc07a45c63b5d585c03f9905e1eef818c24fd408bb562c328672
MD5 | aff7d52edaa331d7b74eefd68d627a1e
BLAKE2b-256 | b4eee4c5313ab53ff1675ade908fec58b9347e0601693fedae6c3a030dc8e36f
File details
Details for the file catboost_extensions-2.1-py3-none-any.whl.
File metadata
- Download URL: catboost_extensions-2.1-py3-none-any.whl
- Upload date:
- Size: 14.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 25019c8bbed26634c36956eca55431bd96d9563c4094689129320f370018aa0d
MD5 | 88459c2989ac2f20781577c585767cc3
BLAKE2b-256 | 52bf7a617322d11e8141b6ca09bbf34514fba0db9bb4119509a4bbe28b9afa7c