CRISP-NAM: Competing Risks Interpretable Survival Prediction with Neural Additive Models

CRISP-NAM (Competing Risks Interpretable Survival Prediction with Neural Additive Models) is an interpretable model for competing risks survival analysis that extends the neural additive architecture to model cause-specific hazards while preserving feature-level interpretability.

Overview

This repository provides a comprehensive framework for competing risks survival analysis with interpretable neural additive models. CRISP-NAM combines the predictive power of deep learning with interpretability through feature-level shape functions, making it suitable for clinical and biomedical applications where understanding feature contributions is crucial.
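The additive structure that makes these models interpretable can be sketched in plain Python: the risk score is a sum of independent per-feature shape functions, each of which can be inspected and plotted on its own. The shape functions below are hypothetical stand-ins for illustration, not the networks the package actually learns.

```python
# Illustrative sketch of the additive idea behind a neural additive model:
# each feature passes through its own shape function, and the risk score is
# the sum of the per-feature contributions. These shape functions are
# hypothetical stand-ins, not networks learned by CRISP-NAM.

def shape_age(age):
    return 0.04 * (age - 50.0)               # roughly linear age effect

def shape_bp(bp):
    return max(0.0, 0.02 * (bp - 120.0))     # effect only above a threshold

def additive_risk(age, bp):
    # Contributions are computed independently, which is what makes the
    # model interpretable: each term can be plotted as a shape function.
    parts = {"age": shape_age(age), "bp": shape_bp(bp)}
    return sum(parts.values()), parts

score, parts = additive_risk(age=60.0, bp=140.0)
print(parts)    # per-feature contributions, each ≈ 0.4
print(score)    # total risk score ≈ 0.8
```

In the real model each shape function is a small neural network per cause-specific hazard, but the interpretability argument is exactly this decomposition into per-feature terms.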

Key Features

  • Interpretable Architecture: Neural additive models that provide feature-level interpretability through shape functions
  • Competing Risks Support: Native handling of multiple competing events in survival analysis
  • Comprehensive Evaluation: Nested cross-validation with robust performance metrics (AUC, Brier Score, Time-dependent C-index)
  • Hyperparameter Optimization: Automated tuning using Optuna with customizable search spaces
  • Rich Visualizations: Automated generation of feature importance plots and shape function visualizations
  • Multiple Training Modes: Standard training, hyperparameter tuning, and nested cross-validation
  • Baseline Comparisons: DeepHit implementation for benchmarking against state-of-the-art methods

Available Datasets

This repository includes four datasets: the Framingham Heart Study, PBC, Support2, and a synthetic competing-risks dataset. Detailed information is available in datasets.md.

Requirements

Python >=3.10

Repository Structure

crisp_nam/
├── blog/                                   # Blog
├── crisp_nam/                              # Main package
│   ├── metrics/
│   │   ├── __init__.py
│   │   ├── calibration.py
│   │   ├── discrimination.py
│   │   └── ipcw.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── crisp_nam_model.py
│   │   └── deephit_model.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── loss.py
│   │   ├── plotting.py
│   │   └── risk_cif.py
│   └── __init__.py
├── data_utils/                             # Data utilities
│   ├── __init__.py
│   ├── load_datasets.py
│   └── survival_datasets.py
├── datasets/                               # Dataset files and loaders
│   ├── metabric/
│   │   ├── cleaned_features_final.csv
│   │   └── label.csv
│   ├── framingham_dataset.py
│   ├── framingham.csv
│   ├── pbc_dataset.py
│   ├── pbc2.csv
│   ├── support_dataset.py
│   ├── support2.csv
│   ├── SurvivalDataset.py
│   ├── synthetic_comprisk.csv
│   └── synthetic_dataset.py
├── docs/                                   # Documentation of Pypi package.
├── results/                                # Results and outputs
│   ├── best_params/                        # Best parameters for dataset and model combinations
│   │   ├── best_params_framingham_deephit.yaml
│   │   ├── best_params_framingham.yaml
│   │   ├── best_params_pbc_deephit.yaml
│   │   ├── best_params_pbc.yaml
│   │   ├── best_params_support.yaml
│   │   ├── best_params_support2_deephit.yaml
│   │   ├── best_params_synthetic_deephit.yaml
│   │   └── best_params_synthetic.yaml
│   ├── logs/                               # Nested CV results and logs
│   │   ├── nested_cv_best_params_*.yaml
│   │   ├── nested_cv_detailed_metrics_*.csv
│   │   ├── nested_cv_metrics_*.xlsx
│   │   ├── nested_cv_raw_metrics_*.json
│   │   └── nested_cv_summary_metrics_*.csv
│   └── plots/                              # Generated plots
│       ├── nested_cv_feature_importance_risk_*_*.png
│       └── nested_cv_shape_functions_risk_*_*.png
├── training_scripts/                       # Training scripts
│   ├── config.yaml
│   ├── model_utils.py
│   ├── train_deephit_cuda.py
│   ├── train_deephit.py
│   ├── train_nested_cv.py                  # Nested cross-validation script
│   ├── train.py
│   ├── tune_optuna_optimized.py
│   └── tune_optuna.py

Install from source

  1. Clone the repository
git clone git@github.com:VectorInstitute/crisp-nam.git
  2. Install

via pip

cd crisp-nam
pip install -e .

via uv

cd crisp-nam
uv sync

Training Scripts

The repository provides several specialized training scripts:

  • train.py: Standard model training with cross-validation and comprehensive evaluation
  • train_nested_cv.py: Robust nested cross-validation for unbiased performance estimation
  • tune_optuna.py: Hyperparameter optimization using Optuna's advanced algorithms
  • tune_optuna_optimized.py: Hyperparameter optimization using Optuna, optimized for running on a GPU
  • train_deephit.py: DeepHit baseline implementation for comparative studies
  • train_deephit_cuda.py: DeepHit baseline implementation optimized for running on a GPU

Each script supports extensive configuration through command-line arguments and YAML config files, enabling reproducible experiments and easy parameter sweeps.

Running training scripts

  1. Modify training parameters in training_scripts/train.py, OR run either of the following commands to see the CLI arguments for passing training parameters:

    python training_scripts/train.py --help
    
    uv run training_scripts/train.py --help
    
  2. Run the training script

    • via python
    source .venv/bin/activate
    python training_scripts/train.py --dataset framingham
    
    • via uv
    uv run training_scripts/train.py --dataset framingham
    

Running Nested Cross-Validation

The nested cross-validation script performs robust model evaluation with hyperparameter optimization using inner and outer cross-validation loops. It automatically generates performance metrics, feature importance plots, and shape function visualizations.
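The inner/outer loop pattern can be sketched as follows. This is a minimal, self-contained illustration of nested cross-validation (the fold counts mirror the `--outer_folds`/`--inner_folds` options), not the script's actual code: the fold helper and the scoring function are toy placeholders for real model training.

```python
import random

# Minimal sketch of nested cross-validation: the outer loop estimates
# generalisation performance on held-out folds, while an inner loop on the
# outer training split selects hyperparameters. The scoring function below
# is a placeholder, not the training logic of train_nested_cv.py.

def kfold_indices(n, k, seed=42):
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]   # k disjoint folds covering all n

def inner_score(train_idx, lr):
    # Placeholder: a real inner loop would train and validate the model on
    # inner folds of train_idx; here we simply prefer lr closest to 0.1.
    return -abs(lr - 0.1)

def nested_cv(n_samples, outer_folds=5, candidates=(0.01, 0.1)):
    outer = kfold_indices(n_samples, outer_folds)
    results = []
    for i, test_idx in enumerate(outer):
        train_idx = [j for fold in outer[:i] + outer[i + 1:] for j in fold]
        # Inner loop: pick the hyperparameter with the best inner score,
        # then (in a real run) retrain on train_idx and score on test_idx.
        best = max(candidates, key=lambda lr: inner_score(train_idx, lr))
        results.append((best, len(test_idx)))
    return results

print(nested_cv(20))   # [(0.1, 4), (0.1, 4), (0.1, 4), (0.1, 4), (0.1, 4)]
```

Because hyperparameters are selected only on inner folds, the outer test folds never influence tuning, which is what makes the performance estimate unbiased.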

via python

python training_scripts/train_nested_cv.py --dataset framingham

via uv

uv run training_scripts/train_nested_cv.py --dataset framingham

Configuration Parameters

All parameters can be passed via command line or specified in a YAML config file:

  1. Dataset Configuration
  • --dataset (str): Dataset to use (choices: framingham, support, pbc, synthetic; default: framingham)
  • --scaling (str): Data scaling method for continuous features (choices: minmax, standard, none; default: standard)
  2. Training Parameters
  • --num_epochs (int): Number of training epochs (default: 250)
  • --batch_size (int): Batch size for training (default: 512)
  • --patience (int): Patience for early stopping (default: 10)
  3. Cross-Validation Configuration
  • --outer_folds (int): Number of outer CV folds (default: 5)
  • --inner_folds (int): Number of inner CV folds for hyperparameter tuning (default: 3)
  • --n_trials (int): Number of Optuna trials per inner fold (default: 20)
  4. Event Weighting
  • --event_weighting (str): Event weighting strategy (choices: none, balanced, custom; default: none)
  • --custom_event_weights (str): Custom weights for events (comma-separated; default: None)
  5. Other Parameters
  • --seed (int): Random seed for reproducibility (default: 42)
  • --config (str): Path to YAML config file (default: looks for config.yaml)
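As an illustration, a config file using the parameter names documented above might look like the following. This is a hypothetical example; consult training_scripts/config.yaml for the exact schema the scripts expect.

```yaml
# Hypothetical config.yaml; key names follow the CLI flags documented above.
dataset: framingham
scaling: standard
num_epochs: 250
batch_size: 512
patience: 10
outer_folds: 5
inner_folds: 3
n_trials: 20
event_weighting: none
seed: 42
```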

Examples

  1. Basic nested CV with default parameters:
python training_scripts/train_nested_cv.py --dataset pbc
  2. Customized nested CV with specific parameters:
python training_scripts/train_nested_cv.py \
    --dataset support \
    --outer_folds 10 \
    --inner_folds 5 \
    --n_trials 50 \
    --num_epochs 500 \
    --event_weighting balanced \
    --scaling minmax \
    --seed 123
  3. Using a config file:
python training_scripts/train_nested_cv.py --config my_config.yaml

Output Files

The script generates several output files in the current directory:

  1. Performance Metrics
  • nested_cv_summary_metrics_{dataset}.csv: Summary table with mean ± std metrics
  • nested_cv_detailed_metrics_{dataset}.csv: Detailed results for each fold
  • nested_cv_metrics_{dataset}.xlsx: Excel file with multiple sheets (Summary, Detailed, Metadata)
  • nested_cv_raw_metrics_{dataset}.json: Raw metrics dictionary for reproducibility
  2. Model Configuration
  • nested_cv_best_params_{dataset}.yaml: Aggregated best hyperparameters across all folds
  3. Visualizations
  • nested_cv_feature_importance_risk_{risk}_{dataset}.png: Feature importance plots
  • nested_cv_shape_functions_risk_{risk}_{dataset}.png: Shape function plots for top features

Plots are saved to results/plots/.

Evaluation Metrics

The script computes the following metrics at different time quantiles (25%, 50%, 75%):

  1. AUC (Area Under the ROC Curve): Time-dependent AUC for discrimination
  • 0.5 = random, >0.7 = good, >0.8 = excellent
  2. TDCI (Time-Dependent Concordance Index): Harrell's C-index adapted for competing risks
  • 0.5 = random, >0.7 = good, >0.8 = excellent
  3. Brier Score: Calibration metric measuring prediction accuracy
  • 0 = perfect, <0.25 = good, >0.25 = poor
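To make the Brier score concrete, here is a simplified, censoring-free version of the metric. The package's implementation handles censoring (e.g. with IPCW weighting), so this toy version is only meant to show what the score measures: the mean squared gap between predicted event probabilities and observed 0/1 outcomes at a fixed time horizon.

```python
# Toy Brier score at a fixed time horizon, ignoring censoring. This is a
# simplified illustration, not the package's IPCW-adjusted implementation.

def brier_score(predicted, observed):
    assert len(predicted) == len(observed)
    return sum((p - o) ** 2 for p, o in zip(predicted, observed)) / len(predicted)

preds = [0.9, 0.2, 0.8, 0.1]   # predicted event probabilities
events = [1, 0, 1, 0]          # observed outcomes by the horizon

# Well-calibrated, confident predictions score low; an uninformative
# constant guess of 0.5 scores exactly 0.25.
print(brier_score(preds, events))       # ≈ 0.025
print(brier_score([0.5] * 4, events))   # 0.25
```

This also explains the thresholds above: 0.25 is the score of a coin-flip predictor, so useful models should fall below it.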

[!NOTE] To install uv, please follow the instructions on its official page.

Contributing

Contributions are welcome! Please open issues or submit pull requests.

License

This project is licensed under the MIT License.

