Federated learning with tree-based models (Random Forest, XGBoost, GBT) via Flower
flwr-trees
Federated learning with tree-based models via Flower.
flwr-trees is a Python library that exposes scikit-learn-compatible federated estimators for Random Forests, XGBoost, and Gradient Boosted Trees, backed by the Flower federated learning framework. All estimators satisfy the scikit-learn BaseEstimator contract and pass check_estimator(), making them drop-in replacements inside any existing sklearn.pipeline.Pipeline.
Status: Alpha (v0.1.0). The public API is stable for all implemented estimators.
Table of Contents
- Overview
- Installation
- Quick Start
- Architecture
- Public API
- Benchmark Results
- Progress
- Remaining Work
- Development
- Citation
- License
Overview
flwr-trees simulates federated training locally: data is partitioned among n_clients virtual clients, and each client trains on its local shard. The resulting models are aggregated according to the chosen FL strategy.
| Strategy | Description | Estimators |
|---|---|---|
| FedForestBagging | Each client trains an independent Random Forest; all trees are pooled at the server. | FederatedRandomForestClassifier, FederatedRandomForestRegressor |
| FedForestCyclic | A single XGBoost Booster is passed round-robin through all clients; each client adds boosting rounds to it. | FederatedXGBClassifier, FederatedXGBRegressor |
| FedHistogramAggregation | Research contribution: in round 1, clients send per-feature split histograms instead of full trees. The server aggregates the histograms to determine global split thresholds; clients then bin their features and train RFs on the discretised data. Reduces round-1 communication by up to 99.8% vs. bagging. | FederatedHistogramRFClassifier, FederatedHistogramRFRegressor |
The bagging and histogram strategies are implemented with Flower's Strategy, NumPyClient, and Parameters types, driven by an in-process orchestrator (no Ray or separate server process required). Setting use_flower=True activates the Flower code path; use_flower=False (the default) runs an equivalent plain Python loop, which is used for testing and check_estimator().
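The bagging strategy summarised in the table can be illustrated with plain scikit-learn. This is a sketch under stated assumptions, not the flwr-trees implementation: the helpers `bagging_round` and `predict_pooled` are hypothetical, the split is IID so every client sees every class, and the pooled trees are combined by averaging their class probabilities (soft voting).

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def bagging_round(X, y, n_clients=5, trees_per_client=20, seed=0):
    """Hypothetical helper: each client fits a local RF; the server pools all trees."""
    rng = np.random.default_rng(seed)
    shards = np.array_split(rng.permutation(len(X)), n_clients)  # IID partition
    pooled = []
    for k, shard in enumerate(shards):
        local_rf = RandomForestClassifier(n_estimators=trees_per_client, random_state=seed + k)
        local_rf.fit(X[shard], y[shard])
        pooled.extend(local_rf.estimators_)  # the trees this client would "send"
    return pooled


def predict_pooled(trees, X):
    """Soft voting: average class probabilities over every pooled tree."""
    proba = np.mean([tree.predict_proba(X) for tree in trees], axis=0)
    return proba.argmax(axis=1)  # valid here because labels are 0..n_classes-1


X, y = make_classification(n_samples=1000, n_features=10, random_state=0)
trees = bagging_round(X[:800], y[:800])
accuracy = float(np.mean(predict_pooled(trees, X[800:]) == y[800:]))
print(len(trees), round(accuracy, 3))
```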
Installation
Requires Python >= 3.11.
# Install from PyPI
pip install flwr-trees
# Install from a source checkout with uv (recommended)
uv sync
# Or install from a source checkout with pip (editable mode)
pip install -e .
# Editable install with optional DP (differential privacy) support
pip install -e ".[privacy]"
Core runtime dependencies:
| Package | Minimum version |
|---|---|
| flwr | 1.29.0 |
| scikit-learn | 1.8.0 |
| xgboost | 3.2.0 |
| numpy | 2.4.4 |
Quick Start
Federated Random Forest
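from flwr_trees import FederatedRandomForestClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
# Example data (the snippets below reuse these variable names)
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
clf = FederatedRandomForestClassifier(
    n_estimators=100, n_clients=5, n_rounds=3,
    iid=False, alpha=0.5,  # non-IID client split; alpha controls heterogeneity
    random_state=42,
)
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))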
Federated XGBoost (cyclic boosting)
from flwr_trees import FederatedXGBClassifier
clf = FederatedXGBClassifier(
n_estimators=50, n_clients=5, n_rounds=2,
max_depth=6, learning_rate=0.1, random_state=42,
)
clf.fit(X_train, y_train)
# GPU acceleration (requires CUDA-capable GPU + XGBoost >= 2.0)
clf_gpu = FederatedXGBClassifier(
n_estimators=50, n_clients=5, n_rounds=2,
device="cuda", # or "cuda:0" for a specific device
random_state=42,
)
Federated Gradient Boosted Trees
from flwr_trees import FederatedGBTClassifier
clf = FederatedGBTClassifier(
n_estimators=100, n_clients=5,
max_depth=3, learning_rate=0.1, random_state=42,
)
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))
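The Public API table lists FederatedGBTClassifier as bagging over per-client GBT models. This README does not say how the per-client models are combined at prediction time, so the scikit-learn sketch below simply assumes that their `predict_proba` outputs are averaged and that every client sees every class. It is not the library's implementation.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

X, y = make_classification(n_samples=1000, n_features=10, random_state=0)
X_train, y_train, X_test, y_test = X[:800], y[:800], X[800:], y[800:]
shards = np.array_split(np.random.default_rng(0).permutation(800), 5)

# Each client boosts its own model on its local shard only.
client_models = [
    GradientBoostingClassifier(n_estimators=100, max_depth=3, learning_rate=0.1,
                               random_state=42).fit(X_train[idx], y_train[idx])
    for idx in shards
]

# Assumed aggregation: average the clients' class probabilities.
proba = np.mean([m.predict_proba(X_test) for m in client_models], axis=0)
y_pred = client_models[0].classes_[proba.argmax(axis=1)]
accuracy = float(np.mean(y_pred == y_test))
print(round(accuracy, 3))
```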
Communication-efficient federated RF (histogram aggregation)
from flwr_trees import FederatedHistogramRFClassifier
clf = FederatedHistogramRFClassifier(
n_estimators=50, n_clients=10, n_rounds=3,
n_bins=32, iid=False, alpha=0.5,
use_flower=True, random_state=42,
)
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))
# Inspect communication cost
print(clf.strategy_.bytes_sent_per_round) # [histogram_bytes, tree_bytes, ...]
print(clf.strategy_.bytes_saved_vs_bagging) # savings vs standard bagging
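The histogram round can be sketched in NumPy to show why it is cheap: each client sends one count vector per feature instead of serialised trees, and the server turns the summed counts into global thresholds. Everything below is an assumption for illustration: the shared fine grid `EDGES` over publicly agreed feature bounds, the approximately equal-mass (quantile) placement of thresholds, and the helper names. The README states only that histograms are aggregated to determine global split thresholds.

```python
import numpy as np

N_FINE = 512  # hypothetical: resolution of each client's histogram
LO, HI = -10.0, 10.0  # hypothetical: publicly agreed feature bounds
EDGES = np.linspace(LO, HI, N_FINE + 1)  # shared fine grid, identical on every client


def client_histograms(X_local):
    """Client side: per-feature counts on the shared grid; no raw rows leave."""
    return np.stack([np.histogram(X_local[:, j], bins=EDGES)[0]
                     for j in range(X_local.shape[1])])


def global_thresholds(hists, n_bins):
    """Server side: sum client histograms, cut each feature at ~equal-mass quantiles."""
    total = np.sum(hists, axis=0)  # shape (n_features, N_FINE)
    targets = np.arange(1, n_bins) / n_bins  # interior quantiles 1/n_bins ... (n_bins-1)/n_bins
    thresholds = []
    for counts in total:
        cdf = np.cumsum(counts) / counts.sum()
        pos = np.searchsorted(cdf, targets)  # first fine bin reaching each quantile
        thresholds.append(np.unique(EDGES[pos + 1]))  # that fine bin's right edge
    return thresholds


def bin_features(X_local, thresholds):
    """Client side: replace each value by its global bin index."""
    return np.column_stack([np.digitize(X_local[:, j], t)
                            for j, t in enumerate(thresholds)])


rng = np.random.default_rng(0)
clients = [rng.normal(loc=k, size=(200, 3)) for k in range(4)]  # non-IID shards
hists = [client_histograms(Xc) for Xc in clients]
thresholds = global_thresholds(hists, n_bins=8)
binned = [bin_features(Xc, thresholds) for Xc in clients]
counts = np.bincount(np.concatenate([b[:, 0] for b in binned]), minlength=8)
print(counts)  # roughly equal counts per global bin for feature 0
```

In this sketch each client's message is a 3 × 512 count array (12,288 bytes as 64-bit integers), regardless of how many samples the client holds.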
Simulating client dropout
from flwr_trees.simulation import simulate_clients, ClientDropoutWrapper
n_rounds = 3  # number of federated rounds to simulate
# X, y: the full dataset to split across simulated clients
partitions = simulate_clients(X, y, n_clients=10, iid=True, random_state=0)
wrapper = ClientDropoutWrapper(partitions, dropout_rate=0.2, min_clients=3, random_state=42)
for round_idx in range(n_rounds):
    active = wrapper.sample(round_idx)  # reproducible: same round_idx, same sample
    # train on the `active` partitions this round
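ClientDropoutWrapper's exact sampling rule is not spelled out in this README. One way to get per-round reproducibility is to seed a fresh generator from `(random_state, round_idx)`; the function below is a hypothetical sketch of that idea, not the wrapper's actual logic. Each client drops independently with probability `dropout_rate`, and dropped clients are re-admitted at random until `min_clients` are active.

```python
import numpy as np


def sample_active(n_clients, round_idx, dropout_rate=0.2, min_clients=3, seed=42):
    """Hypothetical: indices of the clients that participate in `round_idx`."""
    # Seeding from (seed, round_idx) makes each round's draw reproducible on its own,
    # independent of how many rounds were sampled before it.
    rng = np.random.default_rng([seed, round_idx])
    keep = rng.random(n_clients) >= dropout_rate
    active = np.flatnonzero(keep)
    if active.size < min_clients:
        # Re-admit randomly chosen dropped clients to honour the floor.
        dropped = np.flatnonzero(~keep)
        extra = rng.choice(dropped, size=min_clients - active.size, replace=False)
        active = np.sort(np.concatenate([active, extra]))
    return active.tolist()


for round_idx in range(3):
    print(round_idx, sample_active(10, round_idx))
```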
Differential privacy wrappers
from flwr_trees.privacy import NoisyHistogram, DPTreeWrapper
from sklearn.tree import DecisionTreeClassifier
# Add Laplace noise to histogram counts before sending
# (raw_counts: a client's histogram counts, computed beforehand)
noisy_hist = NoisyHistogram(epsilon=1.0, random_state=0)
private_counts = noisy_hist.apply(raw_counts)
# Wrap a decision tree to add Gaussian noise to leaf predictions
tree = DecisionTreeClassifier(max_depth=3).fit(X_train, y_train)
dp_tree = DPTreeWrapper(tree, epsilon=1.0, random_state=0)
proba = dp_tree.predict_proba(X_test) # noisy, sums to 1
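For intuition about what Laplace noise on counts buys: under the standard Laplace mechanism, if adding or removing one record changes the count vector by at most Δ in L1 norm, adding Laplace noise with scale Δ/ε to every count gives ε-differential privacy. Δ = 1 for a single histogram (each record lands in exactly one bin), while per-feature histograms over d features have Δ = d. The helper below is a hypothetical sketch of that mechanism; the README does not specify the noise scale or post-processing NoisyHistogram uses.

```python
import numpy as np


def laplace_histogram(counts, epsilon, sensitivity=1.0, seed=None):
    """Laplace mechanism on histogram counts (sketch; hypothetical helper)."""
    rng = np.random.default_rng(seed)
    noise = rng.laplace(loc=0.0, scale=sensitivity / epsilon, size=np.shape(counts))
    noisy = np.asarray(counts, dtype=float) + noise
    # Clipping is post-processing, so it does not weaken the privacy guarantee.
    return np.clip(noisy, 0.0, None)


raw = np.array([40, 25, 10, 5, 0, 3])
print(laplace_histogram(raw, epsilon=1.0, seed=0).round(2))
```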
Architecture
flwr-trees/
├── src/
│ └── flwr_trees/
│ ├── __init__.py # Public API re-exports
│ ├── estimators/
│ │ ├── base.py # BaseFederatedTreeEstimator (ABC)
│ │ ├── rf.py # FederatedRandomForestClassifier/Regressor
│ │ ├── xgb.py # FederatedXGBClassifier/Regressor
│ │ ├── hist_rf.py # FederatedHistogramRFClassifier/Regressor
│ │ └── gbt.py # FederatedGBTClassifier/Regressor
│ ├── aggregation/
│ │ ├── bagging.py # FedForestBagging + FedForestBaggingClient
│ │ ├── cyclic.py # FedForestCyclic + XGBCyclicClient
│ │ └── histogram.py # FedHistogramAggregation + HistogramClient
│ ├── simulation/
│ │ ├── partitioning.py # simulate_clients, partition_noniid
│ │ └── dropout.py # ClientDropoutWrapper
│ ├── privacy/
│ │ └── dp.py # NoisyHistogram, DPTreeWrapper
│ └── compat/
│ └── array_api.py # get_array_namespace, to_numpy
└── tests/
├── aggregation/ # Strategy / client integration tests
├── compat/ # Array API compliance tests
├── estimators/ # Estimator unit + sklearn compliance tests
├── simulation/ # Partitioning and dropout tests
└── privacy/ # DP noise tests
estimators/: Public sklearn-compatible estimators. Estimators that accept use_flower branch at runtime between a local loop (use_flower=False) and a Flower-wired in-process loop (use_flower=True).
aggregation/: Flower Strategy and NumPyClient implementations. Models are serialised as pickle-encoded uint8 NDArrays passed via Flower Parameters.
simulation/: Data-partitioning utilities. partition_noniid implements Dirichlet-based heterogeneous splits; ClientDropoutWrapper simulates per-round client unavailability.
privacy/: Optional DP noise wrappers. NoisyHistogram adds Laplace noise to histogram counts; DPTreeWrapper adds Gaussian noise to tree leaf predictions.
compat/: Array API Standard utilities for accepting NumPy, CuPy, and PyTorch inputs. They currently handle NumPy arrays and objects exposing .numpy()/.get(); explicit CuPy and PyTorch support still requires additional testing.
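The aggregation/ description says models travel as pickle-encoded uint8 NDArrays. A minimal round trip with NumPy and pickle looks like the sketch below; the helper names are hypothetical, and in a Flower setting such arrays would typically be wrapped with `flwr.common.ndarrays_to_parameters`. Because `pickle.loads` can execute arbitrary code, this format is only safe between trusted parties.

```python
import pickle

import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier


def model_to_ndarray(model):
    """Serialise any picklable model into a flat uint8 array."""
    return np.frombuffer(pickle.dumps(model), dtype=np.uint8)


def ndarray_to_model(array):
    """Inverse of model_to_ndarray. Only unpickle data from trusted peers."""
    return pickle.loads(array.tobytes())


X, y = make_classification(n_samples=200, random_state=0)
tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
payload = model_to_ndarray(tree)
restored = ndarray_to_model(payload)
print(payload.dtype, payload.nbytes)  # nbytes is the size of the message body
```

Summing `payload.nbytes` per message is one plausible way a counter like `bytes_sent_per_round` could be computed, though the README does not say how the library measures it.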
Public API
Estimators
| Class | Type | Strategy | Key fitted attributes |
|---|---|---|---|
| FederatedRandomForestClassifier | Classifier | Bagging | estimators_: list[DecisionTreeClassifier] |
| FederatedRandomForestRegressor | Regressor | Bagging | estimators_: list[DecisionTreeRegressor] |
| FederatedXGBClassifier | Classifier | Cyclic | booster_: xgboost.Booster |
| FederatedXGBRegressor | Regressor | Cyclic | booster_: xgboost.Booster |
| FederatedHistogramRFClassifier | Classifier | Histogram | estimators_, thresholds_, strategy_ |
| FederatedHistogramRFRegressor | Regressor | Histogram | estimators_, thresholds_, strategy_ |
| FederatedGBTClassifier | Classifier | Bagging (per-client GBT) | estimators_: list[GradientBoostingClassifier] |
| FederatedGBTRegressor | Regressor | Bagging (per-client GBT) | estimators_: list[GradientBoostingRegressor] |
All estimators share the base parameters n_clients, n_rounds, and random_state from BaseFederatedTreeEstimator. RF and Histogram estimators additionally accept use_flower.
GPU acceleration via device='cuda' is supported for FederatedXGBClassifier and FederatedXGBRegressor only. The FederatedRandomForest, FederatedHistogramRF, and FederatedGBT estimators use CPU-only scikit-learn backends.
Aggregation strategies
| Class | Key attributes | Notes |
|---|---|---|
| FedForestBagging | bytes_sent_per_round, trees_ | One entry per round; all clients contribute |
| FedForestCyclic | bytes_sent_per_round, booster_ | One entry per Flower step (one client) |
| FedHistogramAggregation | bytes_sent_per_round, bytes_saved_vs_bagging, trees_, thresholds_ | Round 1 = histogram exchange; rounds 2+ = tree collection |
Simulation utilities
from flwr_trees.simulation import simulate_clients, partition_noniid, ClientDropoutWrapper
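partition_noniid is described as a Dirichlet-based split. A common form of that technique, sketched below with a hypothetical helper, draws for each class a vector of client proportions from Dir(alpha) and splits that class's indices accordingly; smaller `alpha` concentrates each class on fewer clients. The exact scheme partition_noniid uses is not documented here.

```python
import numpy as np


def dirichlet_partition(y, n_clients, alpha, seed=0):
    """Hypothetical: split sample indices across clients with Dirichlet label skew."""
    rng = np.random.default_rng(seed)
    shards = [[] for _ in range(n_clients)]
    for cls in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == cls))
        proportions = rng.dirichlet(np.full(n_clients, alpha))
        # Cumulative proportions -> cut points inside this class's index list.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            shards[client].extend(part.tolist())
    return [np.array(sorted(s), dtype=int) for s in shards]


y = np.repeat([0, 1, 2], 100)
parts = dirichlet_partition(y, n_clients=5, alpha=0.5)
for k, p in enumerate(parts):
    print(k, len(p), np.bincount(y[p], minlength=3))  # per-client class counts
```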
Privacy utilities
from flwr_trees.privacy import NoisyHistogram, DPTreeWrapper
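DPTreeWrapper is described as adding Gaussian noise to leaf predictions while keeping probabilities summing to 1. A simplified illustration of that idea, applied here to a fitted tree's output probabilities rather than to its stored leaf values, is sketched below. `noisy_predict_proba` and its `sigma` parameter are hypothetical; how DPTreeWrapper maps `epsilon` to a noise scale is not stated in this README, so no privacy guarantee is claimed for any particular `sigma`.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier


def noisy_predict_proba(tree, X, sigma, seed=None):
    """Hypothetical: Gaussian noise on predicted probabilities, then renormalise."""
    rng = np.random.default_rng(seed)
    proba = tree.predict_proba(X)
    noisy = proba + rng.normal(0.0, sigma, size=proba.shape)
    noisy = np.clip(noisy, 1e-12, None)  # keep every entry strictly positive
    return noisy / noisy.sum(axis=1, keepdims=True)  # each row sums to 1 again


X, y = make_classification(n_samples=300, random_state=0)
tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
proba = noisy_predict_proba(tree, X[:5], sigma=0.1, seed=0)
print(proba.round(3), proba.sum(axis=1))
```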
Benchmark Results
Communication-efficiency comparison between FedForestBagging and FedHistogramAggregation across three datasets. Settings: n_estimators=20, n_clients=5, n_rounds=3, n_bins=32.
| Dataset | Method | Round-1 Bytes | Accuracy | Round-1 Savings |
|---|---|---|---|---|
| breast_cancer | FedForestBagging | 212 KB | 0.9737 | — |
| breast_cancer | FedHistogramAggregation | 116 KB | 0.9649 | 45.2% |
| synthetic_iid | FedForestBagging | 450 KB | 0.9100 | — |
| synthetic_iid | FedHistogramAggregation | 78 KB | 0.8800 | 82.8% |
| synthetic_noniid (alpha=0.3) | FedForestBagging | 400 KB | 0.6650 | — |
| synthetic_noniid (alpha=0.3) | FedHistogramAggregation | 78 KB | 0.7000 | 80.6% |
On the Adult dataset (48k samples, 14 features), FedHistogramAggregation achieves 99.8% round-1 savings (108 KB vs 48 MB) with comparable accuracy to FedForestBagging.
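The Round-1 Savings column is consistent, up to rounding of the KB figures, with the relative reduction in round-1 bytes. Writing B for round-1 bytes and taking 1 MB = 1000 KB, the Adult figure works out as:

```latex
\text{Savings} = 1 - \frac{B_{\text{hist}}}{B_{\text{bag}}}
\qquad\Rightarrow\qquad
1 - \frac{108\ \text{KB}}{48\,000\ \text{KB}} = 1 - 0.00225 \approx 99.8\%
```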
The full benchmark is in benchmarks/communication_benchmark.py (Figure 1 data) and benchmarks/real_world_benchmark.py (multi-dataset comparison).
Progress
| Phase | Description | Status | Tests |
|---|---|---|---|
| 1 | Core infrastructure: BaseFederatedTreeEstimator, array API compat, simulation/partitioning | Complete | 22 |
| 2 | FederatedRandomForestClassifier and FederatedRandomForestRegressor | Complete | 29 |
| 3 | Flower wiring for RF: FedForestBagging, FedForestBaggingClient, use_flower=True | Complete | 5 |
| 4 | FederatedXGBClassifier and FederatedXGBRegressor with FedForestCyclic strategy | Complete | 28 |
| 5 | FedHistogramAggregation + FederatedHistogramRFClassifier/Regressor (research contribution) | Complete | 11 |
| 6 | FederatedGBTClassifier/Regressor, ClientDropoutWrapper, NoisyHistogram, DPTreeWrapper, real-world benchmarks | Complete | 51 |
Total: 171 tests passing (pytest tests/ -v), including full sklearn.utils.estimator_checks.check_estimator() compliance for all eight estimators.
Remaining Work
The following items remain for future development:
- Full Array API compliance: The compat/ module handles NumPy arrays and objects exposing .numpy()/.get() methods. Explicit CuPy and PyTorch tensor support requires additional testing.
- Real distributed deployment: All estimators are simulation-only (in-process). End-to-end testing with a real multi-process Flower deployment is not yet covered.
- FederatedGBTClassifier with a Flower path: The current GBT implementation uses local-only training. Adding use_flower=True support with a FedGBTBagging strategy would enable communication tracking for GBT.
- Advanced privacy composition: Integration with opacus for training-level DP (currently only post-hoc noise wrappers exist).
Development
# Create and activate virtual environment
uv sync
.venv\Scripts\Activate.ps1 # Windows
source .venv/bin/activate # Linux / macOS
# Run the full test suite
pytest tests/ -v
# Verify sklearn estimator compliance for all estimators
python -c "
from sklearn.utils.estimator_checks import check_estimator
from flwr_trees import (
    FederatedRandomForestClassifier, FederatedRandomForestRegressor,
    FederatedXGBClassifier, FederatedXGBRegressor,
    FederatedHistogramRFClassifier,
    FederatedGBTClassifier, FederatedGBTRegressor,
)
for cls in [FederatedRandomForestClassifier, FederatedRandomForestRegressor,
            FederatedXGBClassifier, FederatedXGBRegressor,
            FederatedHistogramRFClassifier,
            FederatedGBTClassifier, FederatedGBTRegressor]:
    check_estimator(cls())
    print(f'{cls.__name__}: PASSED')
"
# Run communication benchmark (Figure 1 data)
python benchmarks/communication_benchmark.py
# Run real-world benchmark (skip large HIGGS download)
python benchmarks/real_world_benchmark.py --skip-higgs
# Lint
ruff check src/
Citation
If you use flwr-trees in research, please cite:
@software{flwr_trees_2025,
author = {Mauktik},
title = {flwr-trees: Federated Learning with Tree-Based Models},
year = {2026},
url = {https://github.com/MAUK9086/flwr-trees},
}
License
Apache License 2.0. See pyproject.toml for full classifier metadata.
Download files
File details
Details for the file flwr_trees-0.1.0.tar.gz.
File metadata
- Download URL: flwr_trees-0.1.0.tar.gz
- Upload date:
- Size: 44.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.11.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 69c98464f725fb5e1dbeda5db2e36029e8274d95117e76e09cf79bf780a7e5df |
| MD5 | 880b1d8f7fd7a20f63140ea159887c40 |
| BLAKE2b-256 | 2b096a86234e29911aab9a94b5843567c293e90220f337e1d41d83bcaadaa541 |
File details
Details for the file flwr_trees-0.1.0-py3-none-any.whl.
File metadata
- Download URL: flwr_trees-0.1.0-py3-none-any.whl
- Upload date:
- Size: 43.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.11.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6233ee83604f79f1a88878ac1b8a404295e0946819c7a8c6ffd739b5e6ad1ede |
| MD5 | 9c362468e3b8095ec34635a12c7e0bb0 |
| BLAKE2b-256 | 44be5162ad65735fc429052f6ae905c213d116e8c3616ac0673ddb08295a838a |