CRISP-NAM: Competing Risks Interpretable Survival Prediction with Neural Additive Models
CRISP-NAM (Competing Risks Interpretable Survival Prediction with Neural Additive Models) is an interpretable neural additive model for competing risks survival analysis. It extends the neural additive architecture to model cause-specific hazards while preserving feature-level interpretability.
Overview
This repository provides a comprehensive framework for competing risks survival analysis with interpretable neural additive models. CRISP-NAM combines the predictive power of deep learning with interpretability through feature-level shape functions, making it suitable for clinical and biomedical applications where understanding feature contributions is crucial.
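The additive idea can be sketched in a few lines: each feature gets its own small network (a "shape function"), and the cause-specific risk score is the sum of the per-feature contributions, which is what makes the model inspectable feature by feature. This is an illustrative NumPy sketch, not the package's API; the tiny ReLU MLPs and all names here are assumptions for demonstration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, hidden, n_causes = 3, 8, 2

# One tiny ReLU MLP per feature -- its output curve over the feature's range
# is the "shape function" that gives feature-level interpretability.
params = [
    (rng.normal(size=hidden),              # w1: input -> hidden
     rng.normal(size=hidden),              # b1
     rng.normal(size=(hidden, n_causes)),  # w2: hidden -> one output per cause
     np.zeros(n_causes))                   # b2
    for _ in range(n_features)
]

def shape_fn(x, w1, b1, w2, b2):
    """Scalar feature -> an additive contribution to each cause-specific risk."""
    h = np.maximum(0.0, x[:, None] * w1 + b1)  # (n, hidden)
    return h @ w2 + b2                         # (n, n_causes)

def additive_risk(X):
    """Cause-specific risk scores as a sum of per-feature shape functions."""
    return sum(shape_fn(X[:, j], *params[j]) for j in range(n_features))

X = rng.normal(size=(5, n_features))
scores = additive_risk(X)  # one risk score per sample and competing cause
```

Because the model is a plain sum over features, plotting each `shape_fn` against its feature directly yields the shape-function visualizations described below.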
Key Features
- Interpretable Architecture: Neural additive models that provide feature-level interpretability through shape functions
- Competing Risks Support: Native handling of multiple competing events in survival analysis
- Comprehensive Evaluation: Nested cross-validation with robust performance metrics (AUC, Brier Score, Time-dependent C-index)
- Hyperparameter Optimization: Automated tuning using Optuna with customizable search spaces
- Rich Visualizations: Automated generation of feature importance plots and shape function visualizations
- Multiple Training Modes: Standard training, hyperparameter tuning, and nested cross-validation
- Baseline Comparisons: DeepHit implementation for benchmarking against state-of-the-art methods
Available Datasets
This repository includes four datasets: the Framingham Heart Study, PBC, Support2, and a synthetic dataset. Detailed information is available in datasets.md.
Requirements
Python >=3.10
Repository Structure
```
crisp_nam/
├── blog/                          # Blog
├── crisp_nam/                     # Main package
│   ├── metrics/
│   │   ├── __init__.py
│   │   ├── calibration.py
│   │   ├── discrimination.py
│   │   └── ipcw.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── crisp_nam_model.py
│   │   └── deephit_model.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── loss.py
│   │   ├── plotting.py
│   │   └── risk_cif.py
│   └── __init__.py
├── data_utils/                    # Data utilities
│   ├── __init__.py
│   ├── load_datasets.py
│   └── survival_datasets.py
├── datasets/                      # Dataset files and loaders
│   ├── metabric/
│   │   ├── cleaned_features_final.csv
│   │   └── label.csv
│   ├── framingham_dataset.py
│   ├── framingham.csv
│   ├── pbc_dataset.py
│   ├── pbc2.csv
│   ├── support_dataset.py
│   ├── support2.csv
│   ├── SurvivalDataset.py
│   ├── synthetic_comprisk.csv
│   └── synthetic_dataset.py
├── docs/                          # Documentation for the PyPI package
├── results/                       # Results and outputs
│   ├── best_params/               # Best parameters per dataset/model combination
│   │   ├── best_params_framingham_deephit.yaml
│   │   ├── best_params_framingham.yaml
│   │   ├── best_params_pbc_deephit.yaml
│   │   ├── best_params_pbc.yaml
│   │   ├── best_params_support.yaml
│   │   ├── best_params_support2_deephit.yaml
│   │   ├── best_params_synthetic_deephit.yaml
│   │   └── best_params_synthetic.yaml
│   ├── logs/                      # Nested CV results and logs
│   │   ├── nested_cv_best_params_*.yaml
│   │   ├── nested_cv_detailed_metrics_*.csv
│   │   ├── nested_cv_metrics_*.xlsx
│   │   ├── nested_cv_raw_metrics_*.json
│   │   └── nested_cv_summary_metrics_*.csv
│   └── plots/                     # Generated plots
│       ├── nested_cv_feature_importance_risk_*_*.png
│       └── nested_cv_shape_functions_risk_*_*.png
├── training_scripts/              # Training scripts
│   ├── config.yaml
│   ├── model_utils.py
│   ├── train_deephit_cuda.py
│   ├── train_deephit.py
│   ├── train_nested_cv.py         # Nested cross-validation script
│   ├── train.py
│   ├── tune_optuna_optimized.py
│   └── tune_optuna.py
```
Install from source
- Clone the repository:

```shell
git clone git@github.com:VectorInstitute/crisp-nam.git
```

- Install via pip:

```shell
cd crisp-nam
pip install -e .
```

- Install via uv:

```shell
cd crisp-nam
uv sync
```
Training Scripts
The repository provides several specialized training scripts:
- train.py: Standard model training with cross-validation and comprehensive evaluation
- train_nested_cv.py: Robust nested cross-validation for unbiased performance estimation
- tune_optuna.py: Hyperparameter optimization using Optuna's advanced algorithms
- tune_optuna_optimized.py: Hyperparameter optimization using Optuna on a GPU
- train_deephit.py: DeepHit baseline implementation for comparative studies
- train_deephit_cuda.py: DeepHit baseline implementation optimized for running on a GPU
Each script supports extensive configuration through command-line arguments and YAML config files, enabling reproducible experiments and easy parameter sweeps.
Running training scripts
- Modify the training parameters in training_scripts/train.py, or run either of the following commands to see the CLI arguments for passing training parameters:

```shell
python training_scripts/train.py --help
uv run training_scripts/train.py --help
```

- Run the training script

  via python:

```shell
source .venv/bin/activate
python training_scripts/train.py --dataset framingham
```

  via uv:

```shell
uv run training_scripts/train.py --dataset framingham
```
Running Nested Cross-Validation
The nested cross-validation script performs robust model evaluation with hyperparameter optimization using inner and outer cross-validation loops. It automatically generates performance metrics, feature importance plots, and shape function visualizations.
via python:

```shell
python training_scripts/train_nested_cv.py --dataset framingham
```

via uv:

```shell
uv run training_scripts/train_nested_cv.py --dataset framingham
```
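The inner/outer loop structure described above can be sketched as follows. This is an illustrative skeleton using scikit-learn's `KFold`, with a dummy scoring rule standing in for model fitting and the Optuna search; it is not the actual script's logic.

```python
import numpy as np
from sklearn.model_selection import KFold

X = np.random.default_rng(42).normal(size=(50, 2))

outer = KFold(n_splits=5, shuffle=True, random_state=42)  # unbiased performance estimate
inner = KFold(n_splits=3, shuffle=True, random_state=42)  # hyperparameter tuning

chosen_params = []
for train_idx, test_idx in outer.split(X):
    # Inner loop: select a hyperparameter using the outer-training data only,
    # so the outer test fold never influences tuning.
    best_param, best_score = None, -np.inf
    for param in (0.1, 1.0, 10.0):  # stand-in for an Optuna trial loop
        fold_scores = []
        for fit_idx, val_idx in inner.split(train_idx):
            # Here you would fit on train_idx[fit_idx] and score on
            # train_idx[val_idx]; this dummy score simply peaks at param == 1.0.
            fold_scores.append(-abs(np.log10(param)))
        if np.mean(fold_scores) > best_score:
            best_score, best_param = np.mean(fold_scores), param
    # Refit with best_param on all outer-training data and evaluate
    # on the held-out outer fold.
    chosen_params.append(best_param)
```

Reporting the mean and standard deviation of the outer-fold scores is what yields the unbiased performance estimates in the summary metrics files.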
Configuration Parameters
All parameters can be passed via command line or specified in a YAML config file:
- Dataset Configuration
  - `--dataset` (str): Dataset to use (choices: `framingham`, `support`, `pbc`, `synthetic`; default: `framingham`)
  - `--scaling` (str): Data scaling method for continuous features (choices: `minmax`, `standard`, `none`; default: `standard`)
- Training Parameters
  - `--num_epochs` (int): Number of training epochs (default: `250`)
  - `--batch_size` (int): Batch size for training (default: `512`)
  - `--patience` (int): Patience for early stopping (default: `10`)
- Cross-Validation Configuration
  - `--outer_folds` (int): Number of outer CV folds (default: `5`)
  - `--inner_folds` (int): Number of inner CV folds for hyperparameter tuning (default: `3`)
  - `--n_trials` (int): Number of Optuna trials per inner fold (default: `20`)
- Event Weighting
  - `--event_weighting` (str): Event weighting strategy (choices: `none`, `balanced`, `custom`; default: `none`)
  - `--custom_event_weights` (str): Custom weights for events (comma-separated; default: `None`)
- Other Parameters
  - `--seed` (int): Random seed for reproducibility (default: `42`)
  - `--config` (str): Path to YAML config file (default: looks for `config.yaml`)
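For intuition on the `balanced` event-weighting option: a common implementation is inverse event-frequency weighting, analogous to scikit-learn's "balanced" class weights. This is a sketch of that convention, not necessarily the exact formula the script uses.

```python
import numpy as np

# Event indicators: 0 = censored, 1..K = competing event types.
events = np.array([0, 1, 1, 1, 2, 0, 1, 2, 1, 0])

# Count occurrences of each event type, ignoring censored observations.
counts = np.bincount(events[events > 0])[1:]      # here: [5, 2]

# Inverse-frequency weights: rarer events get larger weights so the
# loss is not dominated by the most common cause.
weights = counts.sum() / (len(counts) * counts)
```

With five events of type 1 and two of type 2, the rarer cause receives the larger weight (0.7 vs 1.75).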
Examples
- Basic nested CV with default parameters:

```shell
python training_scripts/train_nested_cv.py --dataset pbc
```

- Customized nested CV with specific parameters:

```shell
python training_scripts/train_nested_cv.py \
    --dataset support \
    --outer_folds 10 \
    --inner_folds 5 \
    --n_trials 50 \
    --num_epochs 500 \
    --event_weighting balanced \
    --scaling minmax \
    --seed 123
```

- Using a config file:

```shell
python training_scripts/train_nested_cv.py --config my_config.yaml
```
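When both a config file and CLI flags are given, a typical merge rule is "explicit CLI flags override YAML values". A minimal sketch of that rule, with key names mirroring the options above (the script's actual precedence logic may differ, and the config dict here stands in for the result of loading a YAML file):

```python
import argparse

# Values as they might come from yaml.safe_load(open("my_config.yaml")).
config = {"dataset": "pbc", "num_epochs": 250, "outer_folds": 5}

parser = argparse.ArgumentParser()
parser.add_argument("--dataset")
parser.add_argument("--num_epochs", type=int)
parser.add_argument("--outer_folds", type=int)

# Simulate invoking the script with a single CLI override.
args = parser.parse_args(["--num_epochs", "500"])

# Flags left at their None default fall back to the YAML value;
# anything the user passed explicitly wins.
merged = {**config, **{k: v for k, v in vars(args).items() if v is not None}}
```

This keeps experiments reproducible (the YAML file records the baseline) while still allowing quick one-off parameter sweeps from the command line.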
Output Files
The script generates several output files in the current directory:
- Performance Metrics
  - nested_cv_summary_metrics_{dataset}.csv: Summary table with mean ± std metrics
  - nested_cv_detailed_metrics_{dataset}.csv: Detailed results for each fold
  - nested_cv_metrics_{dataset}.xlsx: Excel file with multiple sheets (Summary, Detailed, Metadata)
  - nested_cv_raw_metrics_{dataset}.json: Raw metrics dictionary for reproducibility
- Model Configuration
  - nested_cv_best_params_{dataset}.yaml: Aggregated best hyperparameters across all folds
- Visualizations
  - nested_cv_feature_importance_risk_{risk}_{dataset}.png: Feature importance plots
  - nested_cv_shape_functions_risk_{risk}_{dataset}.png: Shape function plots for top features

Plots are saved to results/plots/.
Evaluation Metrics
The script computes the following metrics at different time quantiles (25%, 50%, 75%):
- AUC (Area Under the ROC Curve): Time-dependent AUC for discrimination
- 0.5 = random, >0.7 = good, >0.8 = excellent
- TDCI (Time-Dependent Concordance Index): Harrell's C-index adapted for competing risks
- 0.5 = random, >0.7 = good, >0.8 = excellent
- Brier Score: Calibration metric measuring prediction accuracy
- 0 = perfect, <0.25 = good, >0.25 = poor
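As a concrete reference point for the Brier score: without censoring, it is simply the mean squared difference at an evaluation time t between the predicted probability of being event-free and the observed event-free status. The package's ipcw.py presumably adds inverse-probability-of-censoring weights to handle censored observations correctly; this sketch omits them, and all values are illustrative.

```python
import numpy as np

t = 5.0
event_times = np.array([2.0, 6.0, 8.0, 3.0])     # observed event times (no censoring)
pred_surv = np.array([0.3, 0.7, 0.9, 0.4])       # predicted P(event-free beyond t)

# Observed status at t: 1 if the subject is still event-free, else 0.
observed = (event_times > t).astype(float)

# Uncensored Brier score at time t: 0 is perfect, 0.25 matches an
# uninformative constant prediction of 0.5.
brier = np.mean((pred_surv - observed) ** 2)
```

Here the predictions track the outcomes closely, so the score lands well below the 0.25 "good" threshold quoted above.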
> [!NOTE]
> For `uv` installation, please follow the instructions on its official page.
Contributing
Contributions are welcome! Please open issues or submit pull requests.
License
This project is licensed under the MIT License.