scikit-learn compatible alternative random forest algorithms
WildWood is a Python package providing improved random forest algorithms for multiclass classification and regression, introduced in the paper *WildWood: a new random forest algorithm* by S. Gaïffas, I. Merad and Y. Yu (2021). It follows scikit-learn's API and can be used as a drop-in replacement for its random forest algorithms (although multilabel training is not supported yet).
Compared to standard random forest algorithms, WildWood mainly provides the following:
- Improved predictions with fewer trees
- Faster training times (using a histogram strategy similar to LightGBM)
- Native support for categorical features
- Parallel training of the trees in the forest
Multiclass classification can be performed with WildWood using `ForestClassifier`, while regression can be performed with `ForestRegressor`.
Documentation
Documentation is available here:
Installation
The easiest way to install WildWood is using pip:

```shell
pip install wildwood
```

You can also install the latest development version directly from GitHub with

```shell
pip install git+https://github.com/pyensemble/wildwood.git
```
Basic usage
Basic usage follows the standard scikit-learn API. You can simply use

```python
from wildwood import ForestClassifier

clf = ForestClassifier()
clf.fit(X_train, y_train)
y_pred = clf.predict_proba(X_test)[:, 1]
```

to train a classifier with all default hyperparameters. Some of the most interesting ones are highlighted below.
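Since WildWood follows scikit-learn's estimator API, the snippet above can be tried end to end on synthetic data. The sketch below uses scikit-learn's own `RandomForestClassifier` as a stand-in so it runs without WildWood installed; swap in `wildwood.ForestClassifier` for the real thing:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier  # stand-in for wildwood.ForestClassifier
from sklearn.model_selection import train_test_split

# Synthetic binary classification data
X, y = make_classification(n_samples=500, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = RandomForestClassifier(random_state=0)
clf.fit(X_train, y_train)

# Probability of the positive class for each test sample
y_pred = clf.predict_proba(X_test)[:, 1]
print(y_pred.shape)
```

Because the APIs match, the `fit`/`predict_proba` calls are exactly the ones you would use with WildWood's `ForestClassifier`.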
Categorical features
You should avoid one-hot encoding categorical features; instead, tell WildWood which features should be treated as categorical. This is done with the `categorical_features` argument, which is either a boolean mask or an array of indices corresponding to the categorical features.

```python
from wildwood import ForestClassifier

# Assuming columns 0 and 2 are categorical in X
clf = ForestClassifier(categorical_features=[0, 2])
clf.fit(X_train, y_train)
y_pred = clf.predict_proba(X_test)[:, 1]
```
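The boolean-mask form is equivalent to the index list. A minimal numpy sketch building such a mask (the feature count of 5 is purely illustrative):

```python
import numpy as np

n_features = 5        # total number of columns in X (illustrative)
cat_indices = [0, 2]  # columns to treat as categorical

# Boolean mask equivalent to categorical_features=[0, 2]
mask = np.zeros(n_features, dtype=bool)
mask[cat_indices] = True
print(mask)  # [ True False  True False False]
```

Passing `categorical_features=mask` would then select the same columns as passing the index list.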
For now, WildWood supports at most 256 modalities for categorical features, since features are internally encoded using a memory-efficient `uint8` data type. This limitation will be lifted in the near future.
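The cap follows directly from the encoding: a `uint8` can hold only 256 distinct values. The sketch below ordinal-encodes a categorical column with numpy to illustrate the idea; it is not WildWood's internal encoder:

```python
import numpy as np

# A categorical column with 3 modalities
col = np.array(["red", "green", "blue", "green", "red"])

# Ordinal-encode into uint8 codes (a sketch, not WildWood's internal encoder)
modalities, codes = np.unique(col, return_inverse=True)
assert len(modalities) <= 256  # uint8 can represent at most 256 distinct values
codes = codes.astype(np.uint8)
print(codes)  # [2 1 0 1 2] (modalities sorted: blue, green, red)
```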
Improved predictions through aggregation with exponential weights
By default (`aggregation=True`), the predictions produced by WildWood are an aggregation, with exponential weights computed on out-of-bag samples, of the predictions given by all possible prunings of each tree. This is computed exactly and very efficiently, at a cost nearly identical to that of a standard random forest (which simply averages the predictions of the leaves).
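This is not WildWood's exact pruning machinery, but the exponential-weights idea itself can be sketched in a few lines: each candidate predictor (here, standing in for a pruning) gets a weight proportional to `exp(-eta * loss)` measured on held-out samples, and the final prediction is the weighted average. The losses, predictions, and `eta` below are illustrative values:

```python
import numpy as np

eta = 1.0  # temperature / learning-rate parameter (illustrative value)

# Held-out (out-of-bag-like) losses of three candidate predictors,
# e.g. three prunings of the same tree; lower is better
losses = np.array([0.2, 0.5, 0.9])

# Exponential weights, normalized to sum to 1
weights = np.exp(-eta * losses)
weights /= weights.sum()

# Each predictor's probability estimate for one test sample
preds = np.array([0.8, 0.6, 0.4])

# Aggregated prediction: exponentially weighted average
agg = float(weights @ preds)
print(weights, agg)
```

Predictors with smaller held-out loss dominate the average, which is why the aggregated forest can match or beat any single fixed pruning.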
See the documentation for a deeper description of WildWood.