Tree Weighting, Accuracy and Diversity-Preserving Pruning for Random Forests

WRFO: Weighted Random Forest Optimization

WRFO (Weighted Random Forest Optimization) is a machine learning algorithm that enhances Random Forest classifiers by weighting trees and pruning the ensemble while preserving accuracy and diversity. Particle Swarm Optimization (PSO) tunes the tree weights against a multi-objective function that combines accuracy, diversity, and feature entropy, with the aim of improving performance while using fewer trees.

Paper: "Towards Better Random Forests with Tree Weighting, Accuracy and Diversity-Preserving Pruning" (Expert Systems with Applications, under revision)

Key Features

  • Multi-Objective Optimization: Balances accuracy, ensemble diversity, and feature entropy
  • Adaptive Tree Selection: Automatically identifies and weights the most valuable trees
  • Scikit-learn Compatible: Drop-in replacement for sklearn's RandomForestClassifier
  • Parallel Processing: Efficient computation using joblib parallelization
  • Research-Backed: Described in a paper under revision at Expert Systems with Applications

Installation

From source

git clone https://github.com/yourusername/WRFO.git
cd WRFO
pip install -r requirements.txt
pip install -e .

Requirements

  • Python >= 3.7
  • numpy >= 1.20.0
  • pandas >= 1.3.0
  • scikit-learn >= 1.0.0
  • pyswarm >= 0.6
  • scipy >= 1.7.0
  • joblib >= 1.0.0
  • tqdm >= 4.62.0

Quick Start

from wrfo import WRFOClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Train WRFO
clf = WRFOClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Predict
y_pred = clf.predict(X_test)
print(f"Accuracy: {clf.score(X_test, y_test):.3f}")

Usage

Basic Classification

from wrfo import WRFOClassifier

# Initialize with the default settings and a fixed seed
wrfo = WRFOClassifier(
    n_estimators=100,      # Number of trees
    swarm_size=10,         # PSO swarm size
    max_iter=10,           # PSO iterations
    random_state=42        # For reproducibility
)

# Fit and predict (X_train, y_train, X_test come from the Quick Start split)
wrfo.fit(X_train, y_train)
predictions = wrfo.predict(X_test)

Advanced Configuration

wrfo = WRFOClassifier(
    n_estimators=100,
    swarm_size=10,
    max_iter=10,
    accuracy_weight=0.6,    # Weight for accuracy in objective
    diversity_weight=0.4,   # Weight for diversity in objective
    entropy_weight=0.1,     # Weight for entropy in objective
    val_split=0.2,          # Validation split for PSO optimization
    random_state=42,
    n_jobs=-1,              # Use all CPU cores
    verbose=True            # Print progress
)

Access Optimized Weights

# After fitting
print(f"Optimized weights: {wrfo.weights_}")
print(f"Trees selected: {sum(wrfo.weights_ > 0)}/{wrfo.n_estimators}")
print(f"Diversity matrix shape: {wrfo.divmat_.shape}")
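
To illustrate how per-tree weights can drive both prediction and pruning, here is a minimal sketch of a weighted hard vote in which zero-weight trees are simply skipped. The helper name weighted_vote and the hard-voting rule are assumptions made for illustration; the source does not state how WRFO combines the weighted trees internally.

```python
import numpy as np

def weighted_vote(tree_preds, weights, n_classes):
    """Weighted hard vote over integer class labels (hypothetical helper).

    tree_preds: shape (n_trees, n_samples), each row is one tree's predictions.
    weights:    shape (n_trees,), non-negative; a zero weight prunes the tree.
    """
    tree_preds = np.asarray(tree_preds).astype(int)
    weights = np.asarray(weights, dtype=float)
    votes = np.zeros((tree_preds.shape[1], n_classes))
    sample_idx = np.arange(tree_preds.shape[1])
    for preds, w in zip(tree_preds, weights):
        if w > 0:  # pruned trees contribute nothing
            votes[sample_idx, preds] += w
    return votes.argmax(axis=1)

# Three trees, two samples; the third tree is pruned by its zero weight.
example = weighted_vote([[0, 1], [1, 1], [1, 0]], [0.7, 0.2, 0.0], n_classes=2)
```

With these weights the first tree dominates the first sample, so the weighted result can differ from an unweighted majority vote.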

Examples

See the examples/ directory for complete working examples:

  • iris_classification.py: Cross-validation example on Iris dataset
  • custom_dataset_example.py: Template for using WRFO with custom datasets

Run an example:

cd examples
python iris_classification.py

How It Works

WRFO improves upon standard Random Forest through three key steps:

  1. Train Base Ensemble: Creates a Random Forest with n_estimators trees
  2. Compute Diversity Matrix: Calculates pairwise Cohen's kappa diversity between all trees
  3. Optimize Weights: Uses PSO to find optimal tree weights that maximize:
    • Classification accuracy (F1-score)
    • Ensemble diversity (1 - weighted kappa)
    • Feature entropy (Shannon entropy of root features)
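
Step 2 can be sketched with plain NumPy as follows. The function names cohen_kappa and kappa_matrix are hypothetical, and the convention of returning 1.0 when chance agreement equals 1 is an assumption; the package may compute its diversity matrix differently.

```python
import numpy as np

def cohen_kappa(a, b):
    """Cohen's kappa between two label vectors (hypothetical helper)."""
    a = np.asarray(a)
    b = np.asarray(b)
    labels = np.union1d(a, b)
    p_observed = np.mean(a == b)
    # Chance agreement from each rater's marginal label frequencies.
    p_chance = sum(np.mean(a == k) * np.mean(b == k) for k in labels)
    if np.isclose(p_chance, 1.0):
        return 1.0  # assumed convention: both vectors are the same constant
    return float((p_observed - p_chance) / (1.0 - p_chance))

def kappa_matrix(tree_preds):
    """Symmetric matrix of pairwise kappa values, one row per tree."""
    tree_preds = np.asarray(tree_preds)
    n_trees = tree_preds.shape[0]
    K = np.eye(n_trees)
    for i in range(n_trees):
        for j in range(i + 1, n_trees):
            K[i, j] = K[j, i] = cohen_kappa(tree_preds[i], tree_preds[j])
    return K

# Trees 0 and 1 agree everywhere; tree 2 disagrees with both everywhere.
K = kappa_matrix([[0, 1, 0, 1], [0, 1, 0, 1], [1, 0, 1, 0]])
```

Kappa measures agreement, so a low or negative entry marks a pair of trees that disagree often, which is what the diversity term rewards.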

The multi-objective optimization lets WRFO select a diverse, accurate subset of trees, while the entropy term keeps the root features of the retained trees varied, which helps interpretability.
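
One way the three terms might be combined into a single score is sketched below. This is an illustration under stated assumptions, not the package's actual objective: the helper names, the normalized quadratic form used for the weighted kappa, and the weighting of root features by tree weight are all hypothetical. The F1 score is passed in precomputed to keep the sketch short.

```python
import numpy as np

def root_feature_entropy(root_features, weights):
    """Shannon entropy (bits) of the weight mass per root feature (assumed form)."""
    root_features = np.asarray(root_features)
    weights = np.asarray(weights, dtype=float)
    total = weights.sum()
    if total <= 0:
        return 0.0
    mass = np.array([weights[root_features == f].sum()
                     for f in np.unique(root_features)])
    p = mass[mass > 0] / total
    return float(-(p * np.log2(p)).sum())

def weighted_kappa(K, weights):
    """Weight-averaged pairwise kappa, w^T K w with w normalized to sum to 1."""
    w = np.asarray(weights, dtype=float)
    total = w.sum()
    if total <= 0:
        return 1.0  # assumed convention: an empty ensemble has no diversity
    w = w / total
    return float(w @ np.asarray(K, dtype=float) @ w)

def combined_score(f1, K, root_features, weights,
                   accuracy_weight=0.6, diversity_weight=0.4, entropy_weight=0.1):
    """Weighted sum of accuracy, diversity (1 - weighted kappa), and entropy."""
    diversity = 1.0 - weighted_kappa(K, weights)
    entropy = root_feature_entropy(root_features, weights)
    return (accuracy_weight * f1
            + diversity_weight * diversity
            + entropy_weight * entropy)

score = combined_score(f1=0.9, K=np.eye(2), root_features=[0, 1], weights=[1, 1])
```

Because the score is to be maximized, an optimizer that minimizes would be given its negation.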

Algorithm Parameters

Parameter          Default   Description
n_estimators       100       Number of trees in the random forest
swarm_size         10        Number of particles in PSO swarm
max_iter           10        Maximum PSO iterations
accuracy_weight    0.6       Weight for accuracy component
diversity_weight   0.4       Weight for diversity component
entropy_weight     0.1       Weight for entropy component
val_split          0.2       Validation split ratio for PSO
random_state       None      Random seed for reproducibility
n_jobs             -1        Number of parallel jobs
verbose            True      Whether to print progress
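
To make the roles of swarm_size and max_iter concrete, here is a minimal, self-contained PSO loop. It is a generic textbook sketch, not the pyswarm implementation that WRFO relies on; the function name pso_minimize and the coefficient values omega, phip, and phig are illustrative assumptions.

```python
import numpy as np

def pso_minimize(func, lb, ub, swarm_size=10, max_iter=10,
                 omega=0.5, phip=0.5, phig=0.5, seed=0):
    """Minimize func over the box [lb, ub] with a basic particle swarm."""
    rng = np.random.default_rng(seed)
    lb = np.asarray(lb, dtype=float)
    ub = np.asarray(ub, dtype=float)
    x = rng.uniform(lb, ub, size=(swarm_size, lb.size))  # particle positions
    v = np.zeros_like(x)                                 # particle velocities
    pbest = x.copy()                                     # per-particle best
    pbest_val = np.array([func(p) for p in x])
    gbest = pbest[pbest_val.argmin()].copy()             # swarm-wide best
    gbest_val = float(pbest_val.min())
    for _ in range(max_iter):
        rp = rng.random(x.shape)
        rg = rng.random(x.shape)
        # Inertia plus pulls toward the personal and the global best.
        v = omega * v + phip * rp * (pbest - x) + phig * rg * (gbest - x)
        x = np.clip(x + v, lb, ub)
        vals = np.array([func(p) for p in x])
        improved = vals < pbest_val
        pbest[improved] = x[improved]
        pbest_val[improved] = vals[improved]
        if pbest_val.min() < gbest_val:
            gbest = pbest[pbest_val.argmin()].copy()
            gbest_val = float(pbest_val.min())
    return gbest, gbest_val

# Toy objective: squared distance of a 3-element weight vector from 0.5.
best_w, best_val = pso_minimize(lambda w: float(np.sum((w - 0.5) ** 2)),
                                lb=[0, 0, 0], ub=[1, 1, 1])
```

Each iteration evaluates the objective once per particle, so the cost of the search grows with swarm_size times max_iter.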

Citation

If you use WRFO in your research, please cite:

@article{wrfo2026,
  title={Towards Better Random Forests with Tree Weighting, Accuracy and Diversity-Preserving Pruning},
  author={Nour Elislem Karabadji and Ali Assi and Abdelaziz Amara Korba and Ahmed Abdulaziz Al Nuaim and Hassina Seridi and Mohamed Elati and Wajdi Dhifli},
  journal={Expert Systems with Applications},
  year={2026},
  note={Under revision}
}

License

MIT License - see LICENSE file for details

Acknowledgments

  • Built on top of scikit-learn's RandomForestClassifier
  • PSO implementation from pyswarm package
  • Developed for research in ensemble learning optimization
