LightGBM Callbacks
A collection of LightGBM callbacks.
Provides implementations of ProgressBarCallback (#5867) and DartEarlyStoppingCallback (#4805), as well as an LGBMDartEarlyStoppingEstimator that automatically passes these callbacks (#3313, #5808).
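LightGBM calls each callback once per boosting iteration with an environment object, lightgbm.callback.CallbackEnv, whose fields include iteration, begin_iteration and end_iteration. As a rough, hypothetical illustration of the idea behind a progress-reporting callback (not the ProgressBarCallback implementation), the sketch below derives "done/total" progress from those fields. To stay runnable without LightGBM, it uses a stand-in namedtuple with the same field names; make_progress_callback is an illustrative name, not part of this package.

```python
from collections import namedtuple

# Stand-in with the same field names as lightgbm.callback.CallbackEnv.
# Only iteration, begin_iteration and end_iteration are read below.
CallbackEnv = namedtuple(
    "CallbackEnv",
    ["model", "params", "iteration", "begin_iteration", "end_iteration",
     "evaluation_result_list"],
)


def make_progress_callback(log):
    """Hypothetical helper: return a callback that appends 'done/total' to log."""
    def _callback(env):
        done = env.iteration - env.begin_iteration + 1  # env.iteration is 0-based
        total = env.end_iteration - env.begin_iteration
        log.append(f"{done}/{total}")
    return _callback


if __name__ == "__main__":
    log = []
    callback = make_progress_callback(log)
    for i in range(3):
        callback(CallbackEnv(None, {}, i, 0, 3, []))
    print(log)  # ['1/3', '2/3', '3/3']
```

With LightGBM installed, the returned callable could be passed as callbacks=[make_progress_callback(log)] to LGBMRegressor.fit, and LightGBM would supply the real CallbackEnv.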
Installation
Install this via pip (or your favourite package manager):
pip install lightgbm-callbacks
Usage
Scikit-Learn API, simple
from lightgbm import LGBMRegressor
from lightgbm_callbacks import LGBMDartEarlyStoppingEstimator
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)
LGBMDartEarlyStoppingEstimator(
    LGBMRegressor(boosting_type="dart"),  # or "gbdt", ...
    stopping_rounds=10,  # or n_iter_no_change=10
    test_size=0.2,  # or validation_fraction=0.2
    shuffle=False,
    tqdm_cls="rich",  # "auto", "autonotebook", ...
).fit(X_train, y_train)
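The test_size (or validation_fraction) and shuffle arguments control how a validation set is carved out of the training data. Assuming the wrapper follows the semantics of scikit-learn's train_test_split (an assumption, not confirmed here), shuffle=False keeps the row order and takes the last ceil(test_size * n) rows for validation. A minimal pure-Python sketch of that split, using a hypothetical holdout_split helper:

```python
import math


def holdout_split(rows, test_size=0.2):
    """Hypothetical helper: non-shuffled train/validation split.

    Assumes train_test_split(..., shuffle=False) semantics: the validation
    part is the last ceil(test_size * n) rows, with order preserved.
    """
    if not 0 < test_size < 1:
        raise ValueError("test_size must be a float in (0, 1)")
    n_val = math.ceil(test_size * len(rows))
    n_train = len(rows) - n_val
    return rows[:n_train], rows[n_train:]


train, val = holdout_split(list(range(10)), test_size=0.2)
print(train, val)  # [0, 1, 2, 3, 4, 5, 6, 7] [8, 9]
```

Because nothing is shuffled, repeated fits on the same data see the same validation rows.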
Scikit-Learn API, manually passing callbacks
from lightgbm import LGBMRegressor
from lightgbm_callbacks import ProgressBarCallback, DartEarlyStoppingCallback
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)
# Hold out part of the training data as a validation set for early stopping
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train)
early_stopping_callback = DartEarlyStoppingCallback(stopping_rounds=10)
LGBMRegressor().fit(
    X_train,
    y_train,
    eval_set=[(X_train, y_train), (X_val, y_val)],
    callbacks=[
        early_stopping_callback,
        ProgressBarCallback(early_stopping_callback=early_stopping_callback),
    ],
)
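Standard early stopping relies on truncating the final model to its best iteration, which does not carry over to DART: adding a tree rescales the trees dropped in that round, so the first k trees of the final model are roughly not the model that existed after iteration k. LightGBM's built-in early_stopping callback is disabled in dart mode for this reason (it logs that early stopping is not available). The sketch below is a hypothetical, library-free illustration of the bookkeeping such a callback needs, not this package's implementation: track the best validation score, stop after stopping_rounds iterations without improvement, and, for a "save"-style strategy, keep a deep copy of the model at the best iteration. EarlyStoppingTracker and the list-of-trees "model" are illustrative stand-ins.

```python
import copy


class EarlyStoppingTracker:
    """Hypothetical sketch of early-stopping bookkeeping (lower score is better).

    Iterations are counted from 1, so best_iteration is a number of trees.
    """

    def __init__(self, stopping_rounds, save_best=False):
        self.stopping_rounds = stopping_rounds
        self.save_best = save_best
        self.best_score = float("inf")
        self.best_iteration = 0
        self.best_model = None

    def update(self, iteration, score, model):
        """Record one iteration; return True when training should stop."""
        if score < self.best_score:
            self.best_score = score
            self.best_iteration = iteration
            if self.save_best:
                # Snapshot the model as it is now, since a DART model
                # cannot simply be truncated to this iteration later.
                self.best_model = copy.deepcopy(model)
        return iteration - self.best_iteration >= self.stopping_rounds


# Toy run: the "model" is a list of tree ids; the score bottoms out at iteration 3.
scores = [5.0, 4.0, 3.0, 3.5, 3.6, 3.7]
tracker = EarlyStoppingTracker(stopping_rounds=2, save_best=True)
model = []
for iteration, score in enumerate(scores, start=1):
    model.append(iteration)
    if tracker.update(iteration, score, model):
        break
print(iteration, tracker.best_iteration, tracker.best_model)  # 5 3 [1, 2, 3]
```

Training stops after best_iteration + stopping_rounds iterations, while the deep copy preserves the model as it was at the best iteration.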
Details on DartEarlyStoppingCallback
Below is a description of each value of the DartEarlyStoppingCallback method parameter, and the number of iterations actually trained, as observed with lgb.plot_metric, for lgb.LGBMRegressor(boosting_type="dart", n_estimators=1000) trained on the entire sklearn.datasets.load_diabetes() dataset.
Method | Description | Iterations | Actual iterations |
---|---|---|---|
(Baseline) | Early stopping is not used. | n_estimators | 1000 |
"none" | Do nothing and return the original estimator. | min(best_iteration + stopping_rounds, n_estimators) | 50 |
"save" | Save the best model by deep-copying the estimator (using pickle) and return the best model. | min(best_iteration + 1, n_estimators) | 21 |
"refit" | Refit the estimator with the best iteration and return the refitted estimator. | min(best_iteration, n_estimators) | 20 |
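The iteration formulas in the table can be restated as a small function. iterations_trained is a hypothetical helper that only encodes those formulas (stopping_rounds standing for the early-stopping rounds); it does not call the library, and the values in the example are illustrative, not the settings behind the table.

```python
def iterations_trained(method, best_iteration, stopping_rounds, n_estimators):
    """Hypothetical helper restating the table's formula for each method."""
    if method is None:  # baseline: early stopping is not used
        return n_estimators
    if method == "none":
        return min(best_iteration + stopping_rounds, n_estimators)
    if method == "save":
        return min(best_iteration + 1, n_estimators)
    if method == "refit":
        return min(best_iteration, n_estimators)
    raise ValueError(f"unknown method: {method!r}")


for m in (None, "none", "save", "refit"):
    print(m, iterations_trained(m, best_iteration=100, stopping_rounds=10,
                                n_estimators=1000))
# None 1000
# none 110
# save 101
# refit 100
```

The min(..., n_estimators) cap means that when the best iteration falls near the end, no method trains more than n_estimators iterations.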
Contributors ✨
Thanks go to these wonderful people (emoji key):
34j 💻 🤔 📖
This project follows the all-contributors specification. Contributions of any kind welcome!