Tree Weighting, Accuracy and Diversity-Preserving Pruning for Random Forests
Project description
WRFO: Weighted Random Forest Optimization
WRFO (Weighted Random Forest Optimization) is a machine learning algorithm that enhances Random Forest classifiers through tree weighting and accuracy- and diversity-preserving pruning, driven by Particle Swarm Optimization (PSO). By optimizing tree weights under a multi-objective function that combines accuracy, ensemble diversity, and feature entropy, WRFO improves predictive performance while using fewer trees.
Paper: "Towards Better Random Forests with Tree Weighting, Accuracy and Diversity-Preserving Pruning" (Expert Systems with Applications, under revision)
Key Features
- Multi-Objective Optimization: Balances accuracy, ensemble diversity, and feature entropy
- Adaptive Tree Selection: Automatically identifies and weights the most valuable trees
- Scikit-learn Compatible: Drop-in replacement for sklearn's RandomForestClassifier
- Parallel Processing: Efficient computation using joblib parallelization
- Research-Validated: under revision in Expert Systems with Applications
Installation
From source
```shell
git clone https://github.com/wajdidhifli/WRFO.git
cd WRFO
pip install -r requirements.txt
pip install -e .
```
Requirements
- Python >= 3.7
- numpy >= 1.20.0
- pandas >= 1.3.0
- scikit-learn >= 1.0.0
- pyswarm >= 0.6
- scipy >= 1.7.0
- joblib >= 1.0.0
- tqdm >= 4.62.0
Quick Start
```python
from wrfo import WRFOClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Train WRFO
clf = WRFOClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Predict
y_pred = clf.predict(X_test)
print(f"Accuracy: {clf.score(X_test, y_test):.3f}")
```
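To judge what the weighting and pruning buy you, it helps to score a plain scikit-learn forest on the same split first. This baseline uses only sklearn (no WRFO code), with the same data and seed as the quick start above:

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Same data and split as the WRFO quick start
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Unweighted, unpruned baseline forest for comparison
baseline = RandomForestClassifier(n_estimators=100, random_state=42)
baseline.fit(X_train, y_train)
print(f"Baseline RF accuracy: {baseline.score(X_test, y_test):.3f}")
```

Comparing the two accuracy figures (and the number of trees WRFO actually keeps) gives a quick sense of the trade-off on your own data.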
Usage
Basic Classification
```python
from wrfo import WRFOClassifier

# Initialize with default parameters
wrfo = WRFOClassifier(
    n_estimators=100,  # Number of trees
    swarm_size=10,     # PSO swarm size
    max_iter=10,       # PSO iterations
    random_state=42    # For reproducibility
)

# Fit and predict
wrfo.fit(X_train, y_train)
predictions = wrfo.predict(X_test)
```
Advanced Configuration
```python
wrfo = WRFOClassifier(
    n_estimators=100,
    swarm_size=10,
    max_iter=10,
    accuracy_weight=0.6,   # Weight for accuracy in objective
    diversity_weight=0.4,  # Weight for diversity in objective
    entropy_weight=0.1,    # Weight for entropy in objective
    val_split=0.2,         # Validation split for PSO optimization
    random_state=42,
    n_jobs=-1,             # Use all CPU cores
    verbose=True           # Print progress
)
```
Access Optimized Weights
```python
# After fitting
print(f"Optimized weights: {wrfo.weights_}")
print(f"Trees selected: {sum(wrfo.weights_ > 0)}/{wrfo.n_estimators}")
print(f"Diversity matrix shape: {wrfo.divmat_.shape}")
```
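The weight vector acts as a soft pruning mask: trees with zero weight drop out of the vote entirely. To illustrate the idea (this is a sketch of weighted soft voting over a plain sklearn forest, not WRFO's internal implementation, and the weight vector here is a hypothetical random one rather than a PSO-optimized one):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

rf = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)

# Hypothetical weight vector: keep 60 trees, prune the rest (zero weight)
rng = np.random.default_rng(42)
weights = np.zeros(len(rf.estimators_))
kept = rng.choice(len(rf.estimators_), size=60, replace=False)
weights[kept] = rng.random(60)
weights /= weights.sum()  # normalize so votes form a convex combination

# Weighted soft vote: sum of weight * per-tree class probabilities
proba = sum(w * tree.predict_proba(X_test)
            for w, tree in zip(weights, rf.estimators_))
y_pred = rf.classes_[np.argmax(proba, axis=1)]
print(f"Trees kept: {int((weights > 0).sum())}/{len(rf.estimators_)}, "
      f"accuracy: {(y_pred == y_test).mean():.3f}")
```

With a learned rather than random weight vector, the point of WRFO is that such a pruned, weighted vote can match or beat the full unweighted ensemble.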
Examples
See the examples/ directory for complete working examples:
- iris_classification.py: Cross-validation example on the Iris dataset
- custom_dataset_example.py: Template for using WRFO with custom datasets
Run an example:
```shell
cd examples
python iris_classification.py
```
How It Works
WRFO improves upon standard Random Forest through three key steps:
1. Train Base Ensemble: creates a Random Forest with n_estimators trees
2. Compute Diversity Matrix: calculates pairwise Cohen's kappa diversity between all trees
3. Optimize Weights: uses PSO to find optimal tree weights that jointly maximize:
   - Classification accuracy (F1-score)
   - Ensemble diversity (1 - weighted kappa)
   - Feature entropy (Shannon entropy of root features)
The multi-objective optimization allows WRFO to select a diverse, accurate subset of trees while maintaining interpretability through feature diversity.
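The three objective ingredients can be computed from sklearn primitives. The sketch below builds the pairwise kappa matrix, the root-feature entropy, and a combined score for a uniform weight vector; it is an illustration of the quantities described above (using the default 0.6/0.4/0.1 component weights), not WRFO's exact objective:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import cohen_kappa_score, f1_score
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_tr, X_val, y_tr, y_val = train_test_split(X, y, random_state=42)
rf = RandomForestClassifier(n_estimators=20, random_state=42).fit(X_tr, y_tr)

# Per-tree predictions on the held-out validation split
preds = np.array([t.predict(X_val) for t in rf.estimators_])

# Pairwise Cohen's kappa between trees; low kappa = high diversity
n = len(rf.estimators_)
kappa = np.array([[cohen_kappa_score(preds[i], preds[j]) for j in range(n)]
                  for i in range(n)])

# Shannon entropy of the features used at each tree's root split
roots = [t.tree_.feature[0] for t in rf.estimators_]
_, counts = np.unique(roots, return_counts=True)
p = counts / counts.sum()
root_entropy = float(-(p * np.log(p)).sum())

# Combined score for a uniform weight vector w
w = np.full(n, 1.0 / n)
ens_proba = sum(wi * t.predict_proba(X_val) for wi, t in zip(w, rf.estimators_))
ens_pred = rf.classes_[np.argmax(ens_proba, axis=1)]
f1 = f1_score(y_val, ens_pred, average="macro")
div = 1.0 - float(w @ kappa @ w)  # 1 - weighted mean pairwise kappa
score = 0.6 * f1 + 0.4 * div + 0.1 * root_entropy
print(f"F1={f1:.3f}, diversity={div:.3f}, root entropy={root_entropy:.3f}")
```

PSO then searches over the weight vector w to maximize a score of this shape, with zeroed-out weights providing the pruning.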
Algorithm Parameters
| Parameter | Default | Description |
|---|---|---|
| n_estimators | 100 | Number of trees in the random forest |
| swarm_size | 10 | Number of particles in the PSO swarm |
| max_iter | 10 | Maximum PSO iterations |
| accuracy_weight | 0.6 | Weight for the accuracy component |
| diversity_weight | 0.4 | Weight for the diversity component |
| entropy_weight | 0.1 | Weight for the entropy component |
| val_split | 0.2 | Validation split ratio for PSO |
| random_state | None | Random seed for reproducibility |
| n_jobs | -1 | Number of parallel jobs (-1 uses all cores) |
| verbose | True | Whether to print progress |
Citation
If you use WRFO in your research, please cite:
```bibtex
@article{wrfo2026,
  title={Towards Better Random Forests with Tree Weighting, Accuracy and Diversity-Preserving Pruning},
  author={Karabadji, Nour Elislem and Assi, Ali and Amara Korba, Abdelaziz and Al Nuaim, Ahmed Abdulaziz and Seridi, Hassina and Elati, Mohamed and Dhifli, Wajdi},
  journal={Expert Systems with Applications},
  year={2026},
  note={Under revision}
}
```
License
MIT License - see LICENSE file for details
Acknowledgments
- Built on top of scikit-learn's RandomForestClassifier
- PSO implementation from pyswarm package
- Developed for research in ensemble learning optimization