The package ntrees_tuning is an extension to sklearn. It adds an ntrees parameter to Random Forests and Gradient Boosting that controls how many trees are used for prediction. The main benefit is that ntrees can be tuned with respect to the OOB error without retraining a new model for each value of ntrees.
The package introduces subclasses of the sklearn Random Forest and Gradient Boosting classes (RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier, GradientBoostingRegressor). Each subclass adds two new methods, predict_ntrees and tune_ntrees, which make it possible to predict with a chosen number of trees and to tune the ntrees parameter.
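The idea of predicting with only part of a fitted ensemble can be sketched with plain scikit-learn. This is not the package's own code: the helper name predict_first_k_trees and the model settings are assumptions made for illustration, and the sketch covers only prediction, not the OOB-error computation.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor


def predict_first_k_trees(forest, X, k):
    """Average the predictions of the first k fitted trees (hypothetical helper)."""
    if not 1 <= k <= len(forest.estimators_):
        raise ValueError("k must be between 1 and the number of fitted trees")
    # estimators_ holds the fitted decision trees in the order they were built.
    per_tree = np.stack([tree.predict(X) for tree in forest.estimators_[:k]])
    return per_tree.mean(axis=0)


X, y = make_regression(n_samples=200, n_features=20, random_state=42)
forest = RandomForestRegressor(n_estimators=50, random_state=0).fit(X, y)

pred_10 = predict_first_k_trees(forest, X, 10)   # uses 10 of the 50 trees
pred_all = predict_first_k_trees(forest, X, 50)  # same ensemble as forest.predict
```

Because the forest is fitted once and only the averaging step changes, trying many values of k costs far less than refitting a model per value.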
Example usage:
1. Create data:
from sklearn.datasets import make_classification, make_regression
Xcls, ycls = make_classification(n_samples=200, n_features=20, n_classes=3, random_state=42, n_clusters_per_class=3, n_informative=5)
Xreg, yreg = make_regression(n_samples=200, n_features=20, random_state=42)
2. Create and Fit RandomForest and GradientBoosting models for Regression and Classification
New custom classes are introduced for tuning the ntrees parameter. They are direct subclasses of the sklearn classes (RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier, GradientBoostingRegressor).
import ntree_tuning as ntt
rf_cls = ntt.Ntree_RandForest_Classifier(n_estimators=100)
rf_cls.fit(Xcls, ycls)
rf_reg = ntt.Ntree_RandForest_Regressor(n_estimators=100)
rf_reg.fit(Xreg, yreg)
gb_cls = ntt.Ntree_GradBoost_Classifier(n_estimators=100, subsample=0.8)
gb_cls.fit(Xcls, ycls)
gb_reg = ntt.Ntree_GradBoost_Regressor(n_estimators=100, subsample=0.8)
gb_reg.fit(Xreg, yreg)
3. Tune ntrees
You can then call the tune_ntrees method to get a dictionary that maps each ntrees value to its OOB error.
# Gradient Boosting
print(gb_reg.tune_ntrees())
print(gb_cls.tune_ntrees())
# Random Forests
min_trees = 20
max_trees = 80
delta_trees = 5
print(rf_reg.tune_ntrees(Xreg, yreg, min_trees, max_trees, delta_trees))
print(rf_cls.tune_ntrees(Xcls, ycls, min_trees, max_trees, delta_trees))
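Once such a dictionary is available, choosing a value is a small lookup. The sketch below assumes the result has the shape {ntrees: oob_error}, as described above. The helper best_ntrees is hypothetical, and the numbers in example_errors are made-up placeholders, not output from the package.

```python
def best_ntrees(oob_errors):
    """Return the ntrees value with the smallest OOB error (hypothetical helper).

    Ties are resolved in favour of the smaller number of trees.
    """
    if not oob_errors:
        raise ValueError("oob_errors must not be empty")
    # Iterating over the sorted keys makes min() return the smallest ntrees on ties.
    return min(sorted(oob_errors), key=oob_errors.get)


# Placeholder values for illustration only, not real results.
example_errors = {20: 0.31, 25: 0.28, 30: 0.28, 35: 0.30}
chosen = best_ntrees(example_errors)
```

Preferring the smaller ensemble on ties is a design choice: it gives the same estimated error with cheaper prediction.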
4. Predict with ntrees
print(gb_reg.predict_ntrees(Xreg, ntrees=10))
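For comparison, plain scikit-learn already exposes per-stage predictions for Gradient Boosting through staged_predict. The sketch below uses that method to predict with the first k stages. The helper predict_first_k_stages is a hypothetical name, and whether predict_ntrees works this way internally is an assumption, not something the description above states.

```python
from itertools import islice

import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor


def predict_first_k_stages(model, X, k):
    """Return the prediction after the first k boosting stages (hypothetical helper)."""
    if not 1 <= k <= model.n_estimators_:
        raise ValueError("k must be between 1 and the number of fitted stages")
    # staged_predict yields one prediction array per boosting stage, in order.
    return next(islice(model.staged_predict(X), k - 1, None))


X, y = make_regression(n_samples=200, n_features=20, random_state=42)
gb = GradientBoostingRegressor(n_estimators=30, subsample=0.8, random_state=0)
gb.fit(X, y)

pred_10 = predict_first_k_stages(gb, X, 10)                 # first 10 stages only
pred_all = predict_first_k_stages(gb, X, gb.n_estimators_)  # matches gb.predict(X)
```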
File details
Details for the file ntrees_tuning-0.1.1.tar.gz.
File metadata
- Download URL: ntrees_tuning-0.1.1.tar.gz
- Upload date:
- Size: 8.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 19be59831a55d72db866e3a7ce27357be0124f3be68aa92c78c68aeb5e617983 |
| MD5 | 3044879d991b7d714e98a6822b669a08 |
| BLAKE2b-256 | 439c6f9cf07d2994f8be4f5efcc2b62b233ad4129064bdf83ff8249ac2acb6cb |
File details
Details for the file ntrees_tuning-0.1.1-py3-none-any.whl.
File metadata
- Download URL: ntrees_tuning-0.1.1-py3-none-any.whl
- Upload date:
- Size: 9.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 19714b8ecf2f9f81b2f44bd4a0fc88da6fc3797faa2fcfa1b0eee370ccd27f23 |
| MD5 | 00dd2648d608b55f4a0105a983fa1b82 |
| BLAKE2b-256 | 5ddc11ab6702f870515dd2a592b86209d415ed8ac47cbd44a71fca90b01de980 |