Tree Weighting, Accuracy and Diversity-Preserving Pruning for Random Forests

WRFO: Weighted Random Forest Optimization

WRFO (Weighted Random Forest Optimization) is a machine learning algorithm that enhances Random Forest classifiers by weighting trees and pruning them in a way that preserves accuracy and diversity, using Particle Swarm Optimization (PSO). It optimizes tree weights against a multi-objective function that combines accuracy, diversity, and feature entropy, with the aim of improving performance while using fewer trees.

Paper: "Towards Better Random Forests with Tree Weighting, Accuracy and Diversity-Preserving Pruning" (Expert Systems with Applications, under revision)

Key Features

  • Multi-Objective Optimization: Balances accuracy, ensemble diversity, and feature entropy
  • Adaptive Tree Selection: Automatically identifies and weights the most valuable trees
  • Scikit-learn Compatible: Drop-in replacement for sklearn's RandomForestClassifier
  • Parallel Processing: Efficient computation using joblib parallelization
  • Research-Backed: Described in a paper under revision at Expert Systems with Applications

Installation

From source

git clone https://github.com/wajdidhifli/WRFO.git
cd WRFO
pip install -r requirements.txt
pip install -e .

Requirements

  • Python >= 3.7
  • numpy >= 1.20.0
  • pandas >= 1.3.0
  • scikit-learn >= 1.0.0
  • pyswarm >= 0.6
  • scipy >= 1.7.0
  • joblib >= 1.0.0
  • tqdm >= 4.62.0

Quick Start

from wrfo import WRFOClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Train WRFO
clf = WRFOClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Predict
y_pred = clf.predict(X_test)
print(f"Accuracy: {clf.score(X_test, y_test):.3f}")

Usage

Basic Classification

from wrfo import WRFOClassifier

# Initialize with default parameters
wrfo = WRFOClassifier(
    n_estimators=100,      # Number of trees
    swarm_size=10,         # PSO swarm size
    max_iter=10,           # PSO iterations
    random_state=42        # For reproducibility
)

# Fit and predict (X_train, y_train, X_test come from the Quick Start split)
wrfo.fit(X_train, y_train)
predictions = wrfo.predict(X_test)

Advanced Configuration

wrfo = WRFOClassifier(
    n_estimators=100,
    swarm_size=10,
    max_iter=10,
    accuracy_weight=0.6,    # Weight for accuracy in objective
    diversity_weight=0.4,   # Weight for diversity in objective
    entropy_weight=0.1,     # Weight for entropy in objective
    val_split=0.2,          # Validation split for PSO optimization
    random_state=42,
    n_jobs=-1,              # Use all CPU cores
    verbose=True            # Print progress
)

Access Optimized Weights

# After fitting
print(f"Optimized weights: {wrfo.weights_}")
print(f"Trees selected: {sum(wrfo.weights_ > 0)}/{wrfo.n_estimators}")
print(f"Diversity matrix shape: {wrfo.divmat_.shape}")
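The weight vector can also be summarized without the library. The following is a minimal sketch, assuming the weights form a flat sequence of non-negative numbers, one per tree; `summarize_weights` and the sample values are hypothetical and not part of the WRFO API.

```python
def summarize_weights(weights, tol=0.0):
    """Return (tree_index, normalized_weight) pairs for trees whose weight
    exceeds `tol`, sorted from most to least influential."""
    active = [(i, float(w)) for i, w in enumerate(weights) if w > tol]
    total = sum(w for _, w in active)
    if total == 0:
        return []  # no tree survived the threshold
    ranked = sorted(active, key=lambda pair: pair[1], reverse=True)
    return [(i, w / total) for i, w in ranked]


# Hypothetical weights for a five-tree forest (not real WRFO output).
example_weights = [0.0, 0.5, 0.0, 1.5, 2.0]
summary = summarize_weights(example_weights)
for index, share in summary:
    print(f"tree {index}: {share:.1%} of the total weight")
```

With a fitted model, passing `wrfo.weights_` in place of `example_weights` should work as long as that attribute is a one-dimensional array, which the comparison `wrfo.weights_ > 0` above suggests.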

Examples

See the examples/ directory for complete working examples:

  • iris_classification.py: Cross-validation example on Iris dataset
  • custom_dataset_example.py: Template for using WRFO with custom datasets

Run an example:

cd examples
python iris_classification.py

How It Works

WRFO improves upon standard Random Forest through three key steps:

  1. Train Base Ensemble: Creates a Random Forest with n_estimators trees
  2. Compute Diversity Matrix: Calculates pairwise Cohen's kappa between all trees (lower kappa means less agreement, hence more diversity)
  3. Optimize Weights: Uses PSO to find optimal tree weights that maximize:
    • Classification accuracy (F1-score)
    • Ensemble diversity (1 - weighted kappa)
    • Feature entropy (Shannon entropy of root features)
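Step 2 can be illustrated with a small, dependency-free sketch. It computes Cohen's kappa for every pair of trees from their predictions on the same samples. This is an illustration under assumptions, not the package's code: `cohen_kappa`, `kappa_matrix`, and the toy predictions are hypothetical, and whether `divmat_` stores kappa or 1 - kappa is not stated here.

```python
from collections import Counter


def cohen_kappa(a, b):
    """Cohen's kappa between two equal-length label sequences."""
    n = len(a)
    if n == 0 or n != len(b):
        raise ValueError("inputs must be non-empty and of equal length")
    # Observed agreement: fraction of samples on which both trees agree.
    p_o = sum(x == y for x, y in zip(a, b)) / n
    # Chance agreement, from each tree's marginal label frequencies.
    count_a, count_b = Counter(a), Counter(b)
    p_e = sum(count_a[k] * count_b[k] for k in count_a) / (n * n)
    if p_e == 1.0:
        # Both trees always emit the same single label: treat as full agreement.
        return 1.0
    return (p_o - p_e) / (1.0 - p_e)


def kappa_matrix(tree_preds):
    """Symmetric matrix of pairwise kappa values, with 1.0 on the diagonal."""
    m = len(tree_preds)
    mat = [[1.0] * m for _ in range(m)]
    for i in range(m):
        for j in range(i + 1, m):
            mat[i][j] = mat[j][i] = cohen_kappa(tree_preds[i], tree_preds[j])
    return mat


# Hypothetical predictions of three trees on six validation samples.
preds = [
    [0, 0, 1, 1, 2, 2],
    [0, 0, 1, 1, 2, 2],  # identical to tree 0, so kappa is 1
    [0, 1, 1, 2, 2, 0],  # agrees with tree 0 on half of the samples
]
K = kappa_matrix(preds)
for row in K:
    print([round(v, 3) for v in row])
```

Two identical trees get kappa 1 and contribute no diversity, which is why a pruning step can drop one of them without changing the ensemble's vote.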

The multi-objective optimization allows WRFO to select a diverse, accurate subset of trees while maintaining interpretability through feature diversity.
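How the three terms might combine into a single score is sketched below. The exact formula is not given in this README, so several choices are assumptions: a weighted sum using the default coefficients (0.6, 0.4, 0.1), macro-averaged F1 on a weighted majority vote, a weight-product average of pairwise kappa, and Shannon entropy in bits over the weight mass assigned to each root feature. All function names and data are hypothetical.

```python
import math
from collections import defaultdict


def weighted_vote(tree_preds, weights):
    """Weighted majority vote; ties go to the smallest label."""
    votes = []
    for s in range(len(tree_preds[0])):
        score = defaultdict(float)
        for preds, w in zip(tree_preds, weights):
            score[preds[s]] += w
        votes.append(max(sorted(score), key=score.get))
    return votes


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 scores."""
    labels = sorted(set(y_true) | set(y_pred))
    total = 0.0
    for c in labels:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        total += 2 * tp / denom if denom else 0.0
    return total / len(labels)


def weighted_mean_kappa(K, weights):
    """Average pairwise kappa, each pair weighted by w_i * w_j."""
    num = den = 0.0
    for i in range(len(weights)):
        for j in range(i + 1, len(weights)):
            ww = weights[i] * weights[j]
            num += ww * K[i][j]
            den += ww
    return num / den if den else 1.0  # fewer than two active trees: no diversity


def root_entropy(root_features, weights):
    """Shannon entropy (bits) of the weight mass per root feature."""
    mass = defaultdict(float)
    for f, w in zip(root_features, weights):
        mass[f] += w
    total = sum(mass.values())
    if total == 0:
        return 0.0
    return -sum((v / total) * math.log2(v / total) for v in mass.values() if v > 0)


def objective(weights, tree_preds, y_true, K, root_features, a=0.6, d=0.4, e=0.1):
    """Assumed weighted-sum objective; higher is better."""
    f1 = macro_f1(y_true, weighted_vote(tree_preds, weights))
    diversity = 1.0 - weighted_mean_kappa(K, weights)
    return a * f1 + d * diversity + e * root_entropy(root_features, weights)


# Hypothetical data: three trees, six validation samples.
y_true = [0, 0, 1, 1, 2, 2]
tree_preds = [[0, 0, 1, 1, 2, 2], [0, 0, 1, 1, 2, 2], [0, 1, 1, 2, 2, 0]]
K = [[1.0, 1.0, 0.25], [1.0, 1.0, 0.25], [0.25, 0.25, 1.0]]  # precomputed kappa
root_features = [0, 0, 2]  # index of the feature tested at each tree's root

score_single = objective([1, 0, 0], tree_preds, y_true, K, root_features)
score_pair = objective([1, 0, 1], tree_preds, y_true, K, root_features)
print(f"tree 0 only: {score_single:.3f}, trees 0 and 2: {score_pair:.3f}")
```

In this toy case the pair scores higher than the single perfect tree, because the diversity and entropy terms outweigh the small loss in F1. An optimizer that minimizes, such as pyswarm's `pso`, would presumably be given the negated score.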

Algorithm Parameters

| Parameter | Default | Description |
|---|---|---|
| n_estimators | 100 | Number of trees in the random forest |
| swarm_size | 10 | Number of particles in PSO swarm |
| max_iter | 10 | Maximum PSO iterations |
| accuracy_weight | 0.6 | Weight for accuracy component |
| diversity_weight | 0.4 | Weight for diversity component |
| entropy_weight | 0.1 | Weight for entropy component |
| val_split | 0.2 | Validation split ratio for PSO |
| random_state | None | Random seed for reproducibility |
| n_jobs | -1 | Number of parallel jobs |
| verbose | True | Whether to print progress |

Citation

If you use WRFO in your research, please cite:

@article{wrfo2026,
  title={Towards Better Random Forests with Tree Weighting, Accuracy and Diversity-Preserving Pruning},
  author={Nour Elislem Karabadji and Ali Assi and Abdelaziz Amara Korba and Ahmed Abdulaziz Al Nuaim and Hassina Seridi and Mohamed Elati and Wajdi Dhifli},
  journal={Expert Systems with Applications},
  year={2026},
  note={Under revision}
}

License

MIT License - see LICENSE file for details

Acknowledgments

  • Built on top of scikit-learn's RandomForestClassifier
  • PSO implementation from pyswarm package
  • Developed for research in ensemble learning optimization
