
Implementation of Federated Random Survival Forest for partially overlapping data.

Project description

Federated Random Survival Forest for Partially Overlapping Data

federated-rsf is a Python implementation of the Federated Random Survival Forest algorithm for partially overlapping data.

Overview

Federated Random Survival Forest (federated-rsf) enables training random survival forest models across multiple institutions without sharing raw data. It is designed for partially overlapping feature spaces and privacy-sensitive biomedical datasets.

Features

  • Federated random survival forest training
  • Support for partially overlapping feature spaces
  • Compatible with scikit-survival data structures and evaluation methods
  • Privacy-preserving model aggregation
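scikit-survival represents survival outcomes as a NumPy structured array with one boolean event field and one float time field per sample. The sketch below builds such an array with plain NumPy. The field names `event` and `time` are an assumption: they match what `sksurv.util.Surv.from_arrays` produces by default, but scikit-survival accepts any two field names in this layout.

```python
import numpy as np

# Boolean event indicator: True = event observed, False = right-censored
events = np.array([True, False, True, True])
# Observed time: time of the event, or time of censoring
times = np.array([120.0, 365.0, 42.0, 200.0])

# One record per sample, the structured layout scikit-survival expects for y
y = np.empty(len(events), dtype=[("event", "?"), ("time", "<f8")])
y["event"] = events
y["time"] = times

print(y.dtype.names)  # ('event', 'time')
```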

Installation

Dependencies

federated-rsf requires:

  • numpy (>=2.0.0)
  • pandas (>=2.3.0)
  • scikit-learn (>=1.8.0)
  • scikit-survival (>=0.27.0)
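A quick way to check an existing environment against these minimums is `importlib.metadata` from the standard library. The names below are distribution names on PyPI (scikit-learn is imported as `sklearn`, scikit-survival as `sksurv`). This is a simplified sketch: the comparison only looks at leading numeric version components and ignores pre-release suffixes, so `packaging.version.Version` is the more robust choice beyond a sanity check.

```python
from importlib.metadata import PackageNotFoundError, version

# Minimum versions listed above, keyed by PyPI distribution name
MINIMUM_VERSIONS = {
    "numpy": "2.0.0",
    "pandas": "2.3.0",
    "scikit-learn": "1.8.0",
    "scikit-survival": "0.27.0",
}


def as_tuple(v: str) -> tuple[int, ...]:
    """Turn '2.3.1' into (2, 3, 1), padded to three parts; stops at the first
    non-numeric part, so suffixes such as 'rc1' are ignored."""
    parts = []
    for piece in v.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def check_dependencies(minimums: dict[str, str]) -> dict[str, str]:
    """Return 'ok', 'too old (<installed>)' or 'missing' for each package."""
    status = {}
    for name, minimum in minimums.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            status[name] = "missing"
            continue
        ok = as_tuple(installed) >= as_tuple(minimum)
        status[name] = "ok" if ok else f"too old ({installed})"
    return status


for pkg, state in check_dependencies(MINIMUM_VERSIONS).items():
    print(f"{pkg}: {state}")
```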

User installation

The easiest way to install federated-rsf is with pip:

pip install -U federated-rsf

To install from source, clone the repository and then install it with pip:

git clone https://github.com/HauschildLab/FRSF4POD.git
cd FRSF4POD
pip install -U .

To install in editable mode together with the optional development or testing dependencies, use:

pip install -U -e ".[dev]"

or

pip install -U -e ".[test]"

Quick Start

Training a federated model with federated-rsf involves three main steps:

  • First, the local data schemas of all clients are unified into a global schema, which makes the local models compatible for aggregation.
  • Second, a local random survival forest is trained on each client's local data.
  • Third, the estimators of the local models are aggregated and distributed back to the clients.
from federated_rsf.models import (
    FederatedRandomSurvivalForest,
    LocalRandomSurvivalForest,
)
# Assumption: DatasetSchema (used below) is exported from the same module
from federated_rsf.preprocessing import DatasetSchema, SchemaAligner, SchemaCreator
from federated_rsf.testing import create_dummy_data, federate_data

In this example, a dummy dataset is created with the testing module. This module can be used to validate the federated learning pipeline when the actual data is not accessible.

# Parameters
n_samples = 500
n_features = 10
n_clients = 5
random_state = 0

# Create Dummy Dataset
X, y = create_dummy_data(
    n_samples,
    n_features,
    random_state=random_state,
)

# Split the dataset samples across all clients
X_list, y_list = federate_data(
    X,
    y,
    n_clients,
    drop_feature_percentage=0.33,
    random_state=random_state,
)
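The internals of `federate_data` are not described here. As a rough illustration of the idea, the following hypothetical helper shuffles the sample indices, splits them into disjoint client shares, and removes a random fraction of the features on each client, which yields partially overlapping feature spaces. The name `split_partially_overlapping` and the exact dropping rule are assumptions, not the package's implementation.

```python
import random


def split_partially_overlapping(n_samples, columns, n_clients,
                                drop_feature_percentage, random_state=0):
    """Hypothetical sketch: disjoint sample split plus per-client feature dropping.

    Returns one (row_indices, kept_columns) pair per client.
    """
    rng = random.Random(random_state)
    indices = list(range(n_samples))
    rng.shuffle(indices)

    n_drop = int(len(columns) * drop_feature_percentage)
    clients = []
    for c in range(n_clients):
        # Every n_clients-th shuffled index goes to client c, so shares are disjoint
        client_rows = sorted(indices[c::n_clients])
        # Each client independently loses a random subset of the features
        dropped = set(rng.sample(columns, n_drop))
        kept = [col for col in columns if col not in dropped]
        clients.append((client_rows, kept))
    return clients


columns = [f"x{i}" for i in range(10)]
for rows, kept in split_partially_overlapping(500, columns, 5, 0.33):
    print(len(rows), kept)
```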

Next, the columns of the local datasets are aligned to a global schema using the SchemaCreator and the SchemaAligner.

# Create global Schema
schema_creator = SchemaCreator(anonymize=False)
local_columns = [DatasetSchema(X_local.columns) for X_local in X_list]
dataset_schemas = schema_creator.fit_transform(local_columns)

# Align local datasets
X_list_aligned = []

for X_local, schema in zip(X_list, dataset_schemas):
    schema_aligner = SchemaAligner()
    X_aligned = schema_aligner.fit_transform(X_local, schema)
    X_list_aligned.append(X_aligned)
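Conceptually, a global schema can be thought of as the union of all local column names, and alignment as reordering each local table to that schema while filling the columns a client does not have with missing values. The sketch below illustrates that idea with plain Python lists. It is an assumption about what `SchemaCreator` and `SchemaAligner` do and ignores their anonymization option and any other details of the real classes; with pandas, `X_local.reindex(columns=global_schema)` has a similar effect.

```python
import math


def create_global_schema(local_columns):
    """Union of all local column names, sorted for a deterministic order."""
    return sorted(set().union(*local_columns))


def align_rows(rows, local_columns, global_schema):
    """Reorder each row to the global schema; features the client lacks become NaN."""
    position = {name: i for i, name in enumerate(local_columns)}
    return [
        [row[position[name]] if name in position else math.nan for name in global_schema]
        for row in rows
    ]


# Two hypothetical clients with partially overlapping features
client_a = (["age", "bmi"], [[63, 27.1], [70, 31.4]])
client_b = (["bmi", "stage"], [[22.8, 2]])

schema = create_global_schema([client_a[0], client_b[0]])
print(schema)                                        # ['age', 'bmi', 'stage']
print(align_rows(client_b[1], client_b[0], schema))  # [[nan, 22.8, 2]]
```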

The local models can then be trained on the processed local data.

# Train local models
local_models: list[LocalRandomSurvivalForest] = []

for X_local, y_local in zip(X_list_aligned, y_list):
    local_model = LocalRandomSurvivalForest(
        random_state=random_state,
    )
    local_model = local_model.fit(X_local, y_local)
    local_models.append(local_model)
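In the standard random survival forest formulation, each terminal node of a tree stores a Nelson-Aalen estimate of the cumulative hazard for the training samples that reach it: H(t) is the sum, over event times t_i <= t, of d_i / n_i, where d_i is the number of events at t_i and n_i the number of samples still at risk. Assuming the local models follow that formulation, the sketch below computes this estimator for one set of samples in plain Python; it is not the package's code.

```python
from collections import Counter


def nelson_aalen(times, events):
    """Nelson-Aalen cumulative hazard as a list of (event_time, H) steps.

    times:  observed times (event or censoring)
    events: 1/True if the event was observed, 0/False if censored
    """
    deaths = Counter(t for t, e in zip(times, events) if e)
    steps = []
    cumulative = 0.0
    for t in sorted(deaths):
        at_risk = sum(1 for s in times if s >= t)  # still under observation at t
        cumulative += deaths[t] / at_risk
        steps.append((t, cumulative))
    return steps


# Steps at t = 1, 2, 3 with H of about 0.25, 0.583 and 1.583
print(nelson_aalen([1, 2, 2, 3], [1, 1, 0, 1]))
```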

The trained local models are then aggregated and the estimators are redistributed using the federated model.

# Distribute trees between local models
fed_model = FederatedRandomSurvivalForest(local_models=local_models)
fed_model.distribute_trees()
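How `distribute_trees` exchanges estimators is not detailed here. One plausible reading of "aggregation and distribution" is that every client receives the pooled trees of all clients, so its federated ensemble is the union of the local ensembles. The sketch below shows that pooling with plain lists, using strings as stand-ins for trained trees; the function name `pool_estimators` and the pooling rule are assumptions for illustration only. Because all trees were trained on data aligned to the same global schema, a tree from one client can be evaluated on another client's aligned data.

```python
def pool_estimators(local_forests):
    """Hypothetical sketch: give every client the trees of all clients.

    local_forests: one list of trained trees per client.
    Returns one federated list of trees per client.
    """
    pooled = [tree for forest in local_forests for tree in forest]
    # Each client gets its own list object so later changes stay local
    return [list(pooled) for _ in local_forests]


forests = [["a1", "a2"], ["b1"], ["c1", "c2", "c3"]]
federated = pool_estimators(forests)
print(len(federated), len(federated[0]))  # 3 6
```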

Finally, you can compare the local and the federated models, for example with the predict, predict_survival_function and predict_cumulative_hazard_function methods.

# Example visualization of survival function and cumulative hazard function
client_index = 0
n_lines = 5

survival_local = local_models[client_index].predict_survival_function(
    X_list_aligned[client_index]
)

hazard_local = local_models[client_index].predict_cumulative_hazard_function(
    X_list_aligned[client_index]
)

local_models[client_index].use_federated_estimators()

survival_federated = local_models[client_index].predict_survival_function(
    X_list_aligned[client_index]
)

hazard_federated = local_models[client_index].predict_cumulative_hazard_function(
    X_list_aligned[client_index]
)
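A forest's prediction is an average over its trees: in a random survival forest the ensemble cumulative hazard is the mean of the per-tree cumulative hazard functions, so switching to the federated estimators changes the set of trees being averaged. Assuming the federated model averages over the pooled trees in this usual way, the sketch below averages per-tree step functions on a common time grid. Representing each tree's output as a list of (time, value) steps is an assumption made for illustration.

```python
from bisect import bisect_right


def step_value(steps, t):
    """Value at time t of a right-continuous step function (0 before the first step)."""
    times = [s for s, _ in steps]
    i = bisect_right(times, t)
    return steps[i - 1][1] if i > 0 else 0.0


def ensemble_chf(per_tree_steps, grid):
    """Average the per-tree cumulative hazard functions at each grid time."""
    return [
        sum(step_value(steps, t) for steps in per_tree_steps) / len(per_tree_steps)
        for t in grid
    ]


# Hypothetical per-tree cumulative hazards for one sample
tree_1 = [(1.0, 0.2), (3.0, 0.6)]
tree_2 = [(2.0, 0.4)]
print(ensemble_chf([tree_1, tree_2], grid=[0.5, 1.0, 2.0, 3.0]))
```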
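from matplotlib import pyplot as plt

# Plot the first n_lines survival curves, one figure per model
for title, surv in [("Local model", survival_local), ("Federated model", survival_federated)]:
    for i, s in enumerate(surv[:n_lines]):
        plt.step(s.x, s.y, where="post", label=str(i))
    plt.title(title)
    plt.ylabel("Survival probability")
    plt.xlabel("Time in days")
    plt.legend()
    plt.grid(True)
    plt.show()


# Plot the first n_lines cumulative hazard curves, one figure per model
for title, hazard in [("Local model", hazard_local), ("Federated model", hazard_federated)]:
    for i, s in enumerate(hazard[:n_lines]):
        plt.step(s.x, s.y, where="post", label=str(i))
    plt.title(title)
    plt.ylabel("Cumulative hazard")
    plt.xlabel("Time in days")
    plt.legend()
    plt.grid(True)
    plt.show()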
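The plots give a qualitative comparison. For a single number per model, the concordance index is a common choice; scikit-survival provides it as `sksurv.metrics.concordance_index_censored(event_indicator, event_time, estimate)`, which can be applied to risk scores such as those returned by `predict` (presumably, as in scikit-survival's own forests). To show what it measures, here is a simplified, dependency-free sketch of Harrell's C that skips pairs with tied times: among comparable pairs, it counts how often the sample that failed first also has the higher risk score, counting ties in risk as one half.

```python
def harrell_c_index(times, events, risk_scores):
    """Simplified Harrell's concordance index (pairs with tied times are skipped).

    A pair (i, j) is comparable if times[i] < times[j] and sample i had an event.
    It is concordant if the earlier failure also has the higher risk score.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored sample cannot be the earlier failure of a pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    return concordant / comparable if comparable else float("nan")


times = [5, 10, 15, 20]
events = [True, True, False, True]
print(harrell_c_index(times, events, [4.0, 3.0, 2.0, 1.0]))  # 1.0
```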



Download files


Source Distribution

federated_rsf-0.1.0.tar.gz (128.9 kB)

Uploaded Source

Built Distribution


federated_rsf-0.1.0-py3-none-any.whl (14.6 kB)

Uploaded Python 3

File details

Details for the file federated_rsf-0.1.0.tar.gz.

File metadata

  • Download URL: federated_rsf-0.1.0.tar.gz
  • Upload date:
  • Size: 128.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for federated_rsf-0.1.0.tar.gz
Algorithm Hash digest
SHA256 9bedf2a23fef4a63151081bb9c8307489e4b11e2ecdf2b529f3fdac3d6487a04
MD5 47db13584c9d4caeeddf7fc2e8d6b01e
BLAKE2b-256 bb93b288fbaf0aebc0e779781f670844f17a84744f538d24a1b5a5f30c59c515


Provenance

The following attestation bundles were made for federated_rsf-0.1.0.tar.gz:

Publisher: publish.yml on HauschildLab/FRSF4POD


File details

Details for the file federated_rsf-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: federated_rsf-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 14.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for federated_rsf-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 c67eabaa9a56fe73c1c5925acd24d9bdbb7771621f0c60d2cb8ebbb584017d46
MD5 7c7437d0eb86d3e55f4cbaf69ca7f50e
BLAKE2b-256 beb3347780ca786d3dff06efa5bdcc42321582f623a0ce9f7faa492837e633d3


Provenance

The following attestation bundles were made for federated_rsf-0.1.0-py3-none-any.whl:

Publisher: publish.yml on HauschildLab/FRSF4POD

