Implementation of Federated Random Survival Forest for partially overlapping data.
Project description
Federated Random Survival Forest for Partially Overlapping Data
federated-rsf is a Python implementation of the Federated Random Survival Forest algorithm for partially overlapping data.
Overview
Federated Random Survival Forest (federated-rsf) enables training random survival forest models across multiple institutions without sharing raw data. It is designed for partially overlapping feature spaces and privacy-sensitive biomedical datasets.
Features
- Federated random survival forest training
- Support for partially overlapping feature spaces
- Compatible with scikit-survival data structures and evaluation methods
- Privacy-preserving model aggregation
Installation
Dependencies
federated-rsf requires:
- numpy (>=2.0.0)
- pandas (>=2.3.0)
- scikit-learn (>=1.8.0)
- scikit-survival (>=0.27.0)
User installation
The easiest way to install federated-rsf is with pip:
pip install -U federated-rsf
To install from source, clone the repository and then install it using pip:
git clone https://github.com/HauschildLab/FRSF4POD.git
cd FRSF4POD
pip install -U .
To install in editable mode with the optional development or testing libraries, use
pip install -U -e ".[dev]"
or
pip install -U -e ".[test]"
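After installing, it can help to confirm that the environment satisfies the minimum versions listed under Dependencies. The following is a minimal sketch using only the standard library; the helper names (`parse_version`, `meets_minimum`, `check_dependencies`) are hypothetical and not part of federated-rsf, and the version parsing is deliberately simple (it ignores pre-release suffixes).

```python
from importlib.metadata import PackageNotFoundError, version

# Minimum versions copied from the dependency list above.
MINIMUMS = {
    "numpy": "2.0.0",
    "pandas": "2.3.0",
    "scikit-learn": "1.8.0",
    "scikit-survival": "0.27.0",
}


def parse_version(text):
    """Turn '2.3.0' into (2, 3, 0).

    Keeps the leading digits of each dot-separated part and stops at the
    first part that has none (for example 'dev0').
    """
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare two version strings after padding them to equal length."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_dependencies(minimums=MINIMUMS):
    """Return {package: status string} for every listed dependency."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
            continue
        ok = meets_minimum(installed, minimum)
        report[name] = f"{installed} ({'ok' if ok else 'needs >=' + minimum})"
    return report


if __name__ == "__main__":
    for name, status in check_dependencies().items():
        print(f"{name}: {status}")
```

For strict PEP 440 comparisons (pre-releases, local versions), a dedicated version parser would be the safer choice; this sketch only covers plain numeric versions.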
Quick Start
federated-rsf trains federated models in three main steps:
- First, the local data schemas of all clients are unified into a global schema to facilitate the aggregation of models.
- Second, a local RSF model is trained on each client's local data.
- Third, the local estimators from the clients are aggregated and distributed.
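from federated_rsf.models import (
    FederatedRandomSurvivalForest,
    LocalRandomSurvivalForest,
)
from federated_rsf.preprocessing import SchemaAligner, SchemaCreator
from federated_rsf.testing import create_dummy_data, federate_data
# NOTE: DatasetSchema is used further down but is not imported in this example;
# import it from the federated_rsf module that defines it.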
In this example we create a dummy dataset using the testing module. This module can be used to validate the federated learning pipeline when the actual data is not accessible.
# Parameters
n_samples = 500
n_features = 10
n_clients = 5
random_state = 0
# Create dummy dataset
X, y = create_dummy_data(
    n_samples,
    n_features,
    random_state=random_state,
)
# Split the dataset samples across all clients
X_list, y_list = federate_data(
    X,
    y,
    n_clients,
    drop_feature_percentage=0.33,
    random_state=random_state,
)
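To make the idea of partially overlapping data concrete, here is a minimal standard-library sketch of one way such a split could be produced. It is not the implementation of `federate_data`; the function name `split_with_partial_overlap` is hypothetical, and it assumes that the drop percentage means "share of features removed independently at each client".

```python
import random


def split_with_partial_overlap(feature_names, n_rows, n_clients,
                               drop_fraction, seed=0):
    """Assign rows to clients and drop a random share of features per client.

    Returns a list of (row_indices, kept_feature_names), one entry per client.
    """
    rng = random.Random(seed)
    rows = list(range(n_rows))
    rng.shuffle(rows)
    n_drop = int(len(feature_names) * drop_fraction)
    clients = []
    for c in range(n_clients):
        # Round-robin over the shuffled rows gives disjoint, near-equal shares.
        client_rows = sorted(rows[c::n_clients])
        # Each client loses its own random subset of features.
        dropped = set(rng.sample(feature_names, n_drop))
        kept = [name for name in feature_names if name not in dropped]
        clients.append((client_rows, kept))
    return clients


features = [f"feature_{i}" for i in range(10)]
clients = split_with_partial_overlap(
    features, n_rows=500, n_clients=5, drop_fraction=0.33
)
for idx, (client_rows, kept) in enumerate(clients):
    print(f"client {idx}: {len(client_rows)} rows, {len(kept)} features")
```

Because each client drops a different subset, any two clients typically share some features but not all of them, which is the situation the schema alignment step has to handle.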
Next, the columns of the local datasets are aligned to a global schema using the SchemaCreator and the local SchemaAligner.
# Create global schema
schema_creator = SchemaCreator(anonymize=False)
local_columns = [DatasetSchema(X_local.columns) for X_local in X_list]
dataset_schemas = schema_creator.fit_transform(local_columns)
# Align local datasets
X_list_aligned = []
for X_local, schema in zip(X_list, dataset_schemas):
    schema_aligner = SchemaAligner()
    X_aligned = schema_aligner.fit_transform(X_local, schema)
    X_list_aligned.append(X_aligned)
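Conceptually, schema alignment can be pictured as taking the union of all local column names and then reordering each local table to that union. The sketch below illustrates this with plain lists; it is not how `SchemaCreator` or `SchemaAligner` are implemented. The helper names are hypothetical, and filling absent features with NaN is an assumption, since the description does not say how missing columns are represented.

```python
import math


def build_global_schema(local_columns):
    """Union of all local column names, in first-seen order."""
    seen = {}
    for columns in local_columns:
        for name in columns:
            seen.setdefault(name, None)
    return list(seen)


def align_rows(rows, local_columns, global_schema, fill=math.nan):
    """Reorder each row to the global schema, filling absent features."""
    position = {name: i for i, name in enumerate(local_columns)}
    return [
        [row[position[name]] if name in position else fill
         for name in global_schema]
        for row in rows
    ]


# Two hypothetical clients that share only the "bmi" feature.
client_a = (["age", "bmi"], [[61, 24.5], [48, 30.1]])
client_b = (["bmi", "stage"], [[27.3, 2]])

schema = build_global_schema([client_a[0], client_b[0]])
aligned_a = align_rows(client_a[1], client_a[0], schema)
aligned_b = align_rows(client_b[1], client_b[0], schema)
print(schema)
print(aligned_b)
```

After alignment every client's table has the same columns in the same order, so a tree trained at one client refers to feature positions that mean the same thing at every other client.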
The local models can then be trained on the processed local data.
# Train local models
local_models: list[LocalRandomSurvivalForest] = []
for X_local, y_local in zip(X_list_aligned, y_list):
    local_model = LocalRandomSurvivalForest(
        random_state=random_state,
    )
    local_model = local_model.fit(X_local, y_local)
    local_models.append(local_model)
The trained local models are then aggregated and the estimators are redistributed using the federated model.
# Distribute trees between local models
fed_model = FederatedRandomSurvivalForest(local_models=local_models)
fed_model.distribute_trees()
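One simple way to think about aggregating and redistributing estimators is pooling: collect the trees from every client and hand the pooled set back to each of them. The sketch below shows that idea with placeholder strings standing in for trees. It is an assumption for illustration only; the description does not state whether `distribute_trees` pools all trees, samples from them, or weights them, and `distribute_estimators` is a hypothetical name.

```python
def distribute_estimators(local_estimators):
    """Give every client the pooled estimators of all clients.

    `local_estimators` maps client name -> list of that client's trees.
    Returns client name -> list of (origin_client, tree) pairs.
    """
    pool = [
        (origin, tree)
        for origin, trees in local_estimators.items()
        for tree in trees
    ]
    # Each client gets its own copy of the list so later edits stay local.
    return {client: list(pool) for client in local_estimators}


# Placeholder "trees"; in practice these would be fitted estimators.
local = {
    "hospital_a": ["a_tree_0", "a_tree_1"],
    "hospital_b": ["b_tree_0"],
    "hospital_c": ["c_tree_0", "c_tree_1", "c_tree_2"],
}
federated = distribute_estimators(local)
print(len(federated["hospital_b"]))
```

Only fitted trees are exchanged in this picture, never rows of patient data, which is what allows training across institutions without sharing raw data.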
Lastly, you can compare the performance of the local and the federated model, for example using predict_survival_function and predict_cumulative_hazard_function.
# Example visualization of survival function and cumulative hazard function
client_index = 0
n_lines = 5
# Predictions from the client's own (local) estimators
survival_local = local_models[client_index].predict_survival_function(
    X_list_aligned[client_index]
)
hazard_local = local_models[client_index].predict_cumulative_hazard_function(
    X_list_aligned[client_index]
)
# Switch this client to the federated estimators before predicting again
local_models[client_index].use_federated_estimators()
survival_federated = local_models[client_index].predict_survival_function(
    X_list_aligned[client_index]
)
hazard_federated = local_models[client_index].predict_cumulative_hazard_function(
    X_list_aligned[client_index]
)
from matplotlib import pyplot as plt
# One figure per model: local first, then federated
for surv in [survival_local, survival_federated]:
    for i, s in enumerate(surv[:n_lines]):
        plt.step(s.x, s.y, where="post", label=str(i))
    plt.ylabel("Survival probability")
    plt.xlabel("Time in days")
    plt.legend()
    plt.grid(True)
    plt.show()
for hazard in [hazard_local, hazard_federated]:
    for i, s in enumerate(hazard[:n_lines]):
        plt.step(s.x, s.y, where="post", label=str(i))
    plt.ylabel("Cumulative hazard")
    plt.xlabel("Time in days")
    plt.legend()
    plt.grid(True)
    plt.show()
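Plots give a visual impression, but a single number is often easier to compare between the local and the federated model. A common choice for survival models is Harrell's concordance index. The following is a minimal pure-Python sketch of that metric for illustration; the function name and the toy data are hypothetical, and pairs with tied event times are simply skipped. In practice, scikit-survival's `concordance_index_censored` in `sksurv.metrics` is the more complete option.

```python
def concordance_index(event, time, risk):
    """Harrell's C: share of comparable pairs ranked correctly by risk.

    A pair (i, j) is comparable when subject i had an observed event and
    a strictly shorter time than subject j. It is concordant when the
    subject with the earlier event has the higher risk score; ties in
    risk count as half.
    """
    concordant = 0.0
    comparable = 0
    n = len(time)
    for i in range(n):
        if not event[i]:
            continue  # a censored subject cannot be the earlier member
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Toy data: four subjects, the third one is censored.
event = [True, True, False, True]
time = [5, 10, 12, 20]
risk_scores = [0.9, 0.6, 0.7, 0.1]  # higher means earlier expected event

c_index = concordance_index(event, time, risk_scores)
print(f"C-index: {c_index:.2f}")
```

A value of 1.0 means every comparable pair is ordered correctly and 0.5 corresponds to random ordering, so computing the index once with the local estimators and once with the federated estimators on the same held-out data gives a direct comparison.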