A federated Random Forest implementation
fed_rf_mk
Federated Random Forests and XGBoost with PySyft
Privacy-preserving federated learning for tree-based models (RandomForest, XGBoost) using PySyft. Includes orchestration client, datasite server utilities, remote training/evaluation functions, model aggregation, and explainability (SHAP, PFI).
Table of Contents
- What’s New
- Overview
- Repository Structure
- Installation
- Quickstart
- Detailed Usage
  - Parameters
  - RandomForest example
  - XGBoost example
- Data
- Models & Experiments
- Outputs & Logging
- Extending the Project
- Troubleshooting / FAQ
- Changelog
- Citation / License / Acknowledgements
- To‑Verify
What’s New
- Modularized architecture:
  - `remote_tasks.py` for the remote `ml_experiment` and `evaluate_global_model` functions.
  - Orchestrator modules for datasites and weights.
  - Handlers for explainability (SHAP, PFI).
  - Aggregation for RF and evaluation-time ensembling for XGB.
- Configurable explainability:
  - Toggle SHAP/PFI and tune `shap_sample_size` and `pfi_n_repeats`.
- Nested hyperparameters:
  - Use `model_params` for model-specific settings (e.g., XGB's `learning_rate`, `max_depth`, `device`).
- Safer utilities:
  - Better logging and guards for datasite I/O and request status.
- Tests + CI:
  - Unit tests for weights, aggregation, and remote tasks, plus pre-commit hooks and GitHub Actions.
Overview
This project demonstrates a practical federated learning workflow on top of PySyft. It supports:
- Training RF/XGB on multiple datasites without moving raw data.
- Aggregating models (RF) and evaluating ensembles (XGB).
- Explainability across silos via SHAP and PFI with weighted averaging.
High-level flow:
```mermaid
flowchart LR
    A[FLClient] --> B[DataSiteManager]
    B -->|Syft requests| C[(Datasites)]
    C -->|Run| D[remote_tasks.ml_experiment]
    C -->|Run| E[remote_tasks.evaluate_global_model]
    A --> F[WeightManager]
    A --> G[ModelAggregator]
    A --> H[SHAPHandler]
    A --> I[PFIHandler]
    D -->|return model + analysis| G
    D -->|return SHAP| H
    D -->|return PFI| I
    G -->|merged model/ensemble| A
    H -->|averaged SHAP| A
    I -->|averaged PFI| A
```
Repository Structure
- `package-pysyft/client.py`: FL client orchestrator (connects to datasites, coordinates training/evaluation).
- `package-pysyft/server.py`: server utilities (spawn datasites, manage approvals).
- `package-pysyft/datasites.py`: launch datasites, upload the CSV asset, server-side policy toggles.
- `package-pysyft/remote_tasks.py`: remote functions
  - `ml_experiment(data, dataParams, modelParams)`
  - `evaluate_global_model(data, dataParams, modelParams)`
- `package-pysyft/orchestrator/`: orchestrator modules
  - `clients.py`: `DataSiteManager` (connect, send requests, check status).
  - `weights.py`: robust `WeightManager` (normalization with None/negative handling).
- `package-pysyft/aggregation/models.py`: `ModelAggregator` (RF estimator merge, XGB ensemble members).
- `package-pysyft/handlers/`: `SHAPHandler`, `PFIHandler`.
- `package-pysyft/analysis/`: notebooks and pipelines, e.g., `2-testing_aids_clinical.ipynb`.
- `tests/`: unit tests for weights, aggregation, and remote tasks.
- `.github/workflows/ci.yml`: lint + tests CI.
- `pyproject.toml`: build config and dependencies.
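The exact rules inside `WeightManager` are not spelled out here. As a rough sketch of what "normalization with None/negative handling" can look like, the hypothetical `normalize_weights` below assumes that a missing (`None`) weight receives the mean of the explicit weights, that negative weights count as 0, and that an all-zero result falls back to equal weights; the real module may make different choices.

```python
from typing import Optional, Sequence


def normalize_weights(weights: Sequence[Optional[float]]) -> list[float]:
    """Hypothetical sketch: normalize per-silo weights so they sum to 1.

    Assumed rules (not taken from WeightManager):
    - negative weights are clamped to 0;
    - a missing (None) weight gets the mean of the explicit weights,
      or 1.0 if no weight is given at all;
    - if every weight ends up 0, fall back to equal weights.
    """
    if not weights:
        raise ValueError("at least one weight is required")
    clamped = [None if w is None else max(float(w), 0.0) for w in weights]
    explicit = [w for w in clamped if w is not None]
    fill = sum(explicit) / len(explicit) if explicit else 1.0
    filled = [fill if w is None else w for w in clamped]
    total = sum(filled)
    if total == 0.0:
        # Every weight was zero or negative: use equal weights instead.
        return [1.0 / len(filled)] * len(filled)
    return [w / total for w in filled]


print(normalize_weights([0.6, 0.4]))     # [0.6, 0.4]
print(normalize_weights([3, None, -1]))  # roughly [0.667, 0.333, 0.0]
```

With the Quickstart weights (0.6 and 0.4) nothing changes, because they already sum to 1.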
Installation
Requirement: Python 3.10+
Install dev (recommended for running examples and tests):
```bash
python -m venv .venv
source .venv/bin/activate  # Windows: .venv\Scripts\activate
pip install --upgrade pip
pip install ".[dev]"       # quotes keep shells such as zsh from globbing the brackets
pre-commit install
```
Minimal install:
```bash
pip install .
```
Key runtime deps:
- syft==0.9.1
- pandas>=2.0
- scikit-learn>=1.3
- cloudpickle>=3.0
- shap>=0.44
- xgboost>=2.0
Quickstart
Start two train datasites and one eval datasite (example ports/paths):
```python
import threading

from package_pysyft.server import FLServer

# One server per datasite, each backed by its own CSV partition.
servers = [
    FLServer("silo1", 8080, "train_datasets/aids_clinical/part_0.csv", auto_accept=True, analysis_allowed=True),
    FLServer("silo2", 8081, "train_datasets/aids_clinical/part_1.csv", auto_accept=True, analysis_allowed=True),
    FLServer("eval_silo", 8082, "train_datasets/aids_clinical/part_2.csv", auto_accept=True, analysis_allowed=True),
]

# Daemon threads stop when the main process exits, so keep this process
# (e.g. a notebook kernel) alive while the client runs.
for s in servers:
    threading.Thread(target=s.start, daemon=True).start()
```
Train and evaluate with the client:
```python
from package_pysyft.client import FLClient

rf_client = FLClient()
rf_client.add_train_client("silo1", "http://localhost:8080", "fedlearning@rf.com", "****", weight=0.6)
rf_client.add_train_client("silo2", "http://localhost:8081", "fedlearning@rf.com", "****", weight=0.4)
rf_client.add_eval_client("eval_silo", "http://localhost:8082", "fedlearning@rf.com", "****")

dataParams = {
    "target": "cid",
    "ignored_columns": ["cid"],
}

# RandomForest example (top-level hyperparams for RF)
modelParams = {
    "model": None,
    "model_type": "rf",
    "n_base_estimators": 100,
    "n_incremental_estimators": 20,
    "train_size": 0.8,
    "test_size": 0.8,
    "fl_epochs": 1,
}

rf_client.set_data_params(dataParams)
rf_client.set_model_params(modelParams)
rf_client.run_model()
results = rf_client.run_evaluate()
print(results)
```
Detailed Usage
Parameters
Data Parameters (dataParams)
| Name | Type | Default | Description | Example |
|---|---|---|---|---|
| target | str | required | Target column name | "cid" |
| ignored_columns | list[str] | [] | Columns excluded from features | ["cid","id"] |
Top-level Model Parameters (modelParams)
| Name | Type | Default | Description | Example |
|---|---|---|---|---|
| model | bytes or None | None | Serialized seed model (None for cold start) | None |
| model_type | str | "rf" | "rf" or "xgb" | "xgb" |
| n_base_estimators | int | required | Base number of trees | 200 |
| n_incremental_estimators | int | 0 | Extra trees for warm-start | 50 |
| train_size | float | 0.8 | Train split ratio on each site | 0.7 |
| test_size | float | 0.2 | Eval split ratio (use 1.0 for eval silos) | 1.0 |
| fl_epochs | int | 1 | Global rounds | 1 |
| allow_analysis | bool | False | Gate SHAP/PFI analysis (and server policy) | True |
| analysis | dict | {} | Fine-grained analysis config (see below) | {"enabled":True,...} |
| model_params | dict | {} | Model-specific hyperparameters (nested) | {"max_depth":6,...} |
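A `modelParams` dictionary that switches analysis on and nests model-specific settings might look like the sketch below. The key names come from the table above and from the FAQ (`analysis.enabled`, `analysis.shap_sample_size`, `analysis.pfi_n_repeats`); the numeric values are illustrative only, and whether nested `model_params` can be mixed with top-level hyperparameters is not specified here.

```python
# Illustrative values; key names follow the parameter table and the FAQ.
modelParams = {
    "model": None,                 # cold start
    "model_type": "xgb",
    "n_base_estimators": 200,
    "n_incremental_estimators": 50,
    "train_size": 0.8,
    "test_size": 0.2,
    "fl_epochs": 1,
    "allow_analysis": True,        # gates SHAP/PFI analysis
    "analysis": {
        "enabled": True,
        "shap_sample_size": 200,   # smaller -> faster SHAP
        "pfi_n_repeats": 5,        # fewer repeats -> faster PFI
    },
    "model_params": {              # nested, model-specific hyperparameters
        "max_depth": 6,
        "learning_rate": 0.05,
        "device": "cpu",
    },
}
```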
RandomForest Hyperparameters (modelParams)
| Name | Type | Default | Description | Example |
|---|---|---|---|---|
| criterion | str | "gini" | Split criterion | "entropy" |
| max_depth | int or None | None | Max tree depth | 12 |
| min_samples_split | int | 2 | Min samples to split | 2 |
| min_samples_leaf | int | 1 | Min samples per leaf | 1 |
| max_features | str or int | "sqrt" | Features per split | "sqrt" |
| bootstrap | bool | True | Bootstrap samples | True |
| n_jobs | int or None | None | Number of parallel jobs (-1 = all cores) | -1 |
| … | … | … | Additional sklearn RF kwargs supported | … |
XGBoost Hyperparameters (modelParams)
| Name | Type | Default | Description | Example |
|---|---|---|---|---|
| device | str | "cpu" | "cpu" or "cuda" | "cuda" |
| verbosity | int | 1 | Log verbosity, 0 (silent) to 3 (debug) | 1 |
| validate_parameters | bool | True | Validate input params | True |
| disable_default_eval_metric | bool | False | Disable default metric | False |
| learning_rate | float | 0.1 | Step-size shrinkage (eta) | 0.05 |
| max_depth | int | 6 | Max depth | 8 |
| min_child_weight | float | 1 | Min child weight | 1 |
| gamma | float | 0 | Min split loss | 0 |
| subsample | float | 1.0 | Row subsampling | 0.8 |
| colsample_bytree | float | 1.0 | Col subsampling per tree | 0.8 |
| colsample_bylevel | float | 1.0 | Col subsampling per level | 1.0 |
| colsample_bynode | float | 1.0 | Col subsampling per node | 1.0 |
| reg_lambda | float | 1.0 | L2 regularization | 1.0 |
| reg_alpha | float | 0.0 | L1 regularization | 0.0 |
| tree_method | str | "auto" | "hist","approx","exact","auto" | "hist" |
| max_delta_step | float | 0 | Max delta step | 0 |
| scale_pos_weight | float | 1.0 | Class imbalance control | 1.0 |
| booster | str | "gbtree" | Booster type | "gbtree" |
| grow_policy | str | "depthwise" | "depthwise" or "lossguide" | "depthwise" |
| max_leaves | int | 0 | Max leaves | 0 |
| max_bin | int | 256 | Histogram bins | 256 |
| sampling_method | str | "uniform" | "uniform" or "gradient_based" | "uniform" |
| Optional | various | N/A | updater, refresh_leaf, process_type, num_parallel_tree, monotone_constraints, interaction_constraints, multi_strategy | see XGB docs |
RandomForest example
```python
from package_pysyft.client import FLClient

rf_client = FLClient()
# connect clients (see Quickstart)
# ...

dataParams = {"target": "cid", "ignored_columns": ["cid"]}
modelParams = {
    "model": None,
    "model_type": "rf",
    "n_base_estimators": 150,
    "n_incremental_estimators": 30,
    "train_size": 0.8,
    "test_size": 0.8,
    "fl_epochs": 1,
    "allow_analysis": True,
    "max_depth": 10,
    "max_features": "sqrt",
    "bootstrap": True,
}

rf_client.set_data_params(dataParams)
rf_client.set_model_params(modelParams)
rf_client.run_model()
print(rf_client.run_evaluate())
```
XGBoost example
```python
from package_pysyft.client import FLClient

xgb_client = FLClient()
# connect clients (see Quickstart)
# ...

dataParams = {"target": "cid", "ignored_columns": ["cid"]}
modelParams = {
    "model": None,
    "model_type": "xgb",
    "n_base_estimators": 200,
    "n_incremental_estimators": 50,
    "train_size": 0.8,
    "test_size": 0.8,
    "fl_epochs": 1,
    "allow_analysis": True,
    "device": "cpu",  # or "cuda"
    "learning_rate": 0.1,
    "max_depth": 8,
    "subsample": 0.8,
    "colsample_bytree": 0.8,
    "tree_method": "hist",
    "verbosity": 1,
    "validate_parameters": True,
}

xgb_client.set_data_params(dataParams)
xgb_client.set_model_params(modelParams)
xgb_client.run_model()
print(xgb_client.run_evaluate())
```
Data
- Input: CSV files per datasite with a consistent schema.
- `dataParams["target"]` must exist in all files; `ignored_columns` are dropped from the features.
- `datasites.py` uploads a single asset called "Asset" by default (the examples use the first dataset/asset).
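Because a schema mismatch only surfaces once the remote jobs run, it can help to check the CSV partitions locally before launching the datasites. The helper below is hypothetical (it is not part of the package) and uses only the standard library: it requires identical headers across files, including column order (stricter than strictly necessary), checks that the target column exists, and returns the columns left after dropping the target and `ignored_columns`.

```python
import csv
from typing import Iterable


def feature_columns(csv_paths: Iterable[str], target: str, ignored_columns: Iterable[str] = ()) -> list[str]:
    """Hypothetical pre-flight check on local copies of the datasite CSVs."""
    header = None
    for path in csv_paths:
        with open(path, newline="") as f:
            current = next(csv.reader(f), [])  # first row = column names
        if target not in current:
            raise ValueError(f"{path}: target column {target!r} not found")
        if header is None:
            header = current
        elif current != header:
            raise ValueError(f"{path}: header differs from the first file")
    if header is None:
        raise ValueError("no CSV files given")
    excluded = set(ignored_columns) | {target}
    return [column for column in header if column not in excluded]
```

For example, `feature_columns(["part_0.csv", "part_1.csv"], "cid", ["cid"])` would list every column except `cid`, or raise as soon as one partition's header disagrees.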
Models & Experiments
- RandomForest:
  - Training: a local RF per silo; aggregation merges estimators in proportion to the normalized weights.
  - Warm start: `n_incremental_estimators` adds trees in subsequent rounds.
- XGBoost:
  - Training: XGB per silo; for continuation, the booster is reused and `n_estimators` is increased.
  - Evaluation: weighted margin (logit) averaging of the ensemble members; the member with the highest weight seeds the next round.
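The two aggregation rules above can be sketched in plain Python. All three functions are hypothetical illustrations, not the `ModelAggregator` code: `trees_per_silo` turns normalized weights into a per-silo tree budget using largest-remainder rounding (one reasonable reading of "proportionally"), `weighted_margin_average` averages raw binary-classification margins, such as those XGBoost returns from `predict(..., output_margin=True)`, before applying the logistic function, and `best_weighted_member` picks the index of the next round's seed.

```python
import math
from typing import Sequence


def trees_per_silo(weights: Sequence[float], total_trees: int) -> list[int]:
    """Split total_trees across silos in proportion to weights that sum to 1.

    Largest-remainder rounding keeps the counts summing to total_trees.
    """
    raw = [w * total_trees for w in weights]
    counts = [math.floor(x) for x in raw]
    leftover = max(total_trees - sum(counts), 0)
    # Hand the remaining trees to the silos with the largest fractional parts.
    by_remainder = sorted(range(len(raw)), key=lambda i: raw[i] - counts[i], reverse=True)
    for i in by_remainder[:leftover]:
        counts[i] += 1
    return counts


def _sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def weighted_margin_average(margins: Sequence[Sequence[float]], weights: Sequence[float]) -> list[float]:
    """Weighted average of each member's raw margins (one list per member),
    mapped to positive-class probabilities."""
    total = sum(weights)
    n_samples = len(margins[0])
    averaged = [
        sum(w * member[j] for w, member in zip(weights, margins)) / total
        for j in range(n_samples)
    ]
    return [_sigmoid(z) for z in averaged]


def best_weighted_member(weights: Sequence[float]) -> int:
    # Index of the member with the largest weight (first one wins ties).
    return max(range(len(weights)), key=lambda i: weights[i])


print(trees_per_silo([0.6, 0.4], 100))                                 # [60, 40]
print(weighted_margin_average([[0.0, 2.0], [0.0, -2.0]], [0.5, 0.5]))  # [0.5, 0.5]
print(best_weighted_member([0.2, 0.5, 0.3]))                           # 1
```

With scikit-learn forests, such counts would decide how many fitted trees to take from each silo's `estimators_`; pooling trees into one forest only makes sense when every silo saw the same set of classes.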
Explainability:
- SHAP: mean absolute SHAP per feature; averaged across silos using normalized weights.
- PFI: permutation importance on eval split; mean/std per feature; averaged across silos.
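Both importance measures reduce to one score per feature per silo, so the cross-silo step is a weighted average. A minimal sketch with hypothetical helper names: `mean_abs_shap` collapses a rows-by-features matrix of SHAP values for one silo, and `average_importances` combines the per-silo dictionaries using normalized weights. It assumes plain float scores and handles a feature that some silos do not report by renormalizing over the silos that do; combining PFI standard deviations would need its own rule and is left out.

```python
from typing import Mapping, Sequence


def mean_abs_shap(shap_rows: Sequence[Sequence[float]], feature_names: Sequence[str]) -> dict[str, float]:
    """Mean absolute SHAP value per feature for one silo (rows x features)."""
    n_rows = len(shap_rows)
    return {
        name: sum(abs(row[j]) for row in shap_rows) / n_rows
        for j, name in enumerate(feature_names)
    }


def average_importances(per_silo: Sequence[Mapping[str, float]], weights: Sequence[float]) -> dict[str, float]:
    """Weighted average of per-feature scores across silos.

    A feature reported by only some silos is averaged over those silos,
    with their weights renormalized.
    """
    weighted_sum: dict[str, float] = {}
    weight_mass: dict[str, float] = {}
    for scores, w in zip(per_silo, weights):
        for feature, value in scores.items():
            weighted_sum[feature] = weighted_sum.get(feature, 0.0) + w * value
            weight_mass[feature] = weight_mass.get(feature, 0.0) + w
    return {f: weighted_sum[f] / weight_mass[f] for f in weighted_sum if weight_mass[f] > 0}


silo_a = mean_abs_shap([[1.0, -2.0], [-3.0, 4.0]], ["age", "cd40"])  # {'age': 2.0, 'cd40': 3.0}
silo_b = {"age": 1.0, "cd40": 1.0}
print(average_importances([silo_a, silo_b], [0.75, 0.25]))  # {'age': 1.75, 'cd40': 2.5}
```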
Outputs & Logging
- Remote training returns `"model"` (serialized), `"model_type"`, sizes, and optionally `"shap_data"` and `"pfi_data"`.
- The client logs progress; datasites log request approvals. Arrays in SHAP/PFI results are converted to Python floats for downstream use.
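One practical reason for the float conversion is that NumPy arrays and several NumPy scalar types (e.g. `float32`) are not serializable by the standard `json` module. The hypothetical `to_builtin_floats` below sketches one way to do the conversion by duck typing: anything with `.tolist()` (NumPy arrays and scalars, `array.array`) or `.item()` is unwrapped, numeric values become `float`, and booleans, strings and `None` pass through unchanged.

```python
import array
from collections.abc import Mapping
from numbers import Real
from typing import Any


def to_builtin_floats(obj: Any) -> Any:
    """Recursively turn numeric scalars and arrays into plain Python floats."""
    if isinstance(obj, Mapping):
        return {key: to_builtin_floats(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_builtin_floats(value) for value in obj]
    if obj is None or isinstance(obj, (bool, str, bytes)):
        return obj  # leave non-numeric leaves untouched
    if hasattr(obj, "tolist"):  # NumPy arrays/scalars, array.array
        return to_builtin_floats(obj.tolist())
    if hasattr(obj, "item"):  # other 0-d array-likes
        return to_builtin_floats(obj.item())
    if isinstance(obj, Real):
        return float(obj)
    return obj


pfi_like = {"mean": array.array("d", [0.25, 1.5]), "std": (0.0, 1)}
print(to_builtin_floats(pfi_like))  # {'mean': [0.25, 1.5], 'std': [0.0, 1.0]}
```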
Extending the Project
- Add new model types by:
  - updating `remote_tasks.ml_experiment` to construct and train the model;
  - extending `ModelAggregator` with the aggregation/ensembling semantics.
- Add new explainability methods through handlers similar to SHAP/PFI.
Troubleshooting / FAQ
- Evaluation fails with “Reference don’t match: dict”
  - Ensure the evaluation function is available on the datasite. The client has a fallback that creates a Syft function ad hoc if it is missing; confirm that your server auto-accepts requests, or approve them manually.
- Missing packages
  - Install with `pip install ".[dev]"` to get all dependencies (including shap and xgboost).
- SHAP/PFI is slow
  - Reduce `analysis.shap_sample_size` or `analysis.pfi_n_repeats`, or disable analysis via `analysis.enabled=False`.
Changelog
- 0.2.0 (2025-09-06)
  - Modular orchestrator/handlers/aggregation
  - Nested `model_params` support for RF/XGB
  - Explainability controls (SHAP/PFI) and parameters
  - Safer datasite loading + logging, fallback eval registration
  - Tests + CI, pre-commit hooks
Citation / License / Acknowledgements
- License: MIT
- Built with:
File details
Details for the file fed_rf_mk-0.0.7rc1.tar.gz.
File metadata
- Download URL: fed_rf_mk-0.0.7rc1.tar.gz
- Upload date:
- Size: 30.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 69bd2eb925da0b11ba5d2a561a961cfcb495c674b8c0ecf34a4bb3b5665afb8c |
| MD5 | d3b58d1af4a33acfda10fae779224786 |
| BLAKE2b-256 | 783ce1400659169d06a78bc7d6b78e8bd1d9d9a0f18bf0277eb4146988db7b97 |
File details
Details for the file fed_rf_mk-0.0.7rc1-py3-none-any.whl.
File metadata
- Download URL: fed_rf_mk-0.0.7rc1-py3-none-any.whl
- Upload date:
- Size: 30.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1195c0d7d2534c5563cf0ddfbf6fb1fc9f60d0c52ade80db87868c84225194cf |
| MD5 | 2718029cd79f090cd113de5d89c636b9 |
| BLAKE2b-256 | 48402f923ff360a55c2e6f04850b9fd7055a6d2966c0b05599b897e86944f999 |