Composable PyTorch modules for survival predictive models.
Survcraft
Survcraft is a Python library for building survival prediction models with PyTorch. It provides composable neural network input modules, survival distribution modules, loss functions, and scikit-learn-style adapters for training models on right-censored time-to-event data.
The package is intended for research and experimentation where the model architecture, survival distribution, and training loss need to be mixed, matched, or extended.
Features
- PyTorch modules for common survival distributions, including exponential, Weibull, log-normal, Levy, inverse Gaussian, proportional hazards, accelerated failure time, and mixtures.
- Scikit-learn-compatible adapters for assembling and fitting survival predictors.
- Loss modules for full likelihood, partial likelihood, Brier score, squared loss, and batch-time classification losses.
- A simple feed-forward input adapter plus lower-level module APIs for custom input networks and custom survival modules.
- Optional plotting helpers for model outputs and training histories.
Installation
Survcraft requires Python 3.9 or newer.
The package is published on PyPI:
```bash
pip install survcraft
```
You can also install it directly from GitHub:
```bash
pip install git+https://github.com/compbiomed-unito/survcraft.git
```
Optional dependency groups are available for common workflows:
```bash
pip install "survcraft[plotting]"
pip install "survcraft[tutorials]"
pip install "survcraft[sksurv]"
pip install "survcraft[test]"
```
The core package depends on PyTorch, NumPy, and scikit-learn.
Quick Start
The high-level API combines three parts:
- an input adapter that maps features to raw survival parameters,
- a survival adapter that turns parameters into a survival distribution, and
- a loss module used during training.
```python
import numpy as np

from survcraft import adapters as ad
from survcraft import loss_modules as lm

# Synthetic data: 128 samples with 4 features each
rng = np.random.default_rng(0)
X = rng.normal(size=(128, 4)).astype(np.float32)

# Right-censored target: structured array with "event" and "time" fields
y = np.zeros(128, dtype=[("event", "?"), ("time", "f4")])
y["event"] = rng.random(128) < 0.7  # roughly 70% observed events
y["time"] = rng.uniform(0.1, 5.0, size=128)

# Feed-forward input network -> Weibull distribution, trained with the full likelihood
model = ad.SurvivalPredictor(
    input=ad.FeedForwardNetAdapter(hidden_sizes=[16, 16]),
    survival=ad.WeibullSurvivalAdapter(),
    loss=lm.FullLikelihoodLoss(),
    epochs=50,
    batch_size=32,
    learning_rate=1e-3,
)
model.fit(X, y)

# Survival and failure curves for the first 5 samples on a grid of 50 time points
times = np.linspace(0.1, 5.0, 50, dtype=np.float32)
survival = model.predict("survival", X[:5], times)
failure = model.predict("failure", X[:5], times)
```
`y` is expected to be a NumPy structured array with a boolean `event` field and a numeric `time` field. This is compatible with the common scikit-survival target format.
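If event indicators and times are stored in separate arrays, they can be packed into this format with plain NumPy. The helper below is a hypothetical sketch (`make_survival_target` is not part of survcraft) that uses the same field names and dtypes as the Quick Start example and, as an extra assumption of this sketch, rejects non-positive times.

```python
import numpy as np


def make_survival_target(event, time):
    """Pack event indicators and times into an ("event", "time") structured array.

    Hypothetical helper, not a survcraft function.
    """
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=np.float32)
    if event.ndim != 1 or event.shape != time.shape:
        raise ValueError("event and time must be 1-D arrays of equal length")
    # Assumption of this sketch: observed times must be strictly positive
    if np.any(time <= 0):
        raise ValueError("observed times must be positive")
    y = np.empty(event.shape[0], dtype=[("event", "?"), ("time", "f4")])
    y["event"] = event
    y["time"] = time
    return y


y = make_survival_target([True, False, True], [1.5, 3.0, 0.7])
print(y.dtype.names)  # ('event', 'time')
```

With scikit-survival installed, `sksurv.util.Surv.from_arrays(event, time)` builds a structured array with the same `event` and `time` field names.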
Core Concepts
Input adapters
Input adapters construct PyTorch modules that map feature vectors to the raw parameters required by a survival module.
`FeedForwardNetAdapter` builds a configurable multilayer perceptron. `LinearFunctionInputAdapter` is useful for deterministic simulations and simple baseline models.
Custom input models can be integrated by following the adapter pattern in `survcraft.adapters.BaseInputAdapter`.
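To illustrate how a linear map from features to raw parameters can drive a deterministic simulation, the sketch below draws right-censored Weibull times whose scale is `exp(X @ coef)`. It is a hypothetical plain-NumPy illustration, not the behavior of `LinearFunctionInputAdapter`: the function name, the log-linear scale, the fixed shape, and the exponential censoring distribution are all assumptions.

```python
import numpy as np


def simulate_weibull_data(X, coef, shape=1.5, censor_scale=3.0, seed=0):
    """Simulate right-censored data whose Weibull scale is log-linear in X.

    Hypothetical helper: the raw linear output X @ coef is mapped to a
    positive scale with exp(); survcraft may parameterize this differently.
    """
    rng = np.random.default_rng(seed)
    scale = np.exp(np.asarray(X, dtype=float) @ np.asarray(coef, dtype=float))
    n = scale.shape[0]
    # Generator.weibull draws with unit scale, so multiply by the per-sample scale
    event_time = scale * rng.weibull(shape, size=n)
    # Independent exponential censoring times (an assumption of this sketch)
    censor_time = rng.exponential(censor_scale, size=n)
    y = np.empty(n, dtype=[("event", "?"), ("time", "f4")])
    y["event"] = event_time <= censor_time  # observed if the event precedes censoring
    y["time"] = np.minimum(event_time, censor_time)
    return y


X = np.random.default_rng(1).normal(size=(128, 4)).astype(np.float32)
y = simulate_weibull_data(X, coef=[0.5, -0.3, 0.0, 0.2])
print(f"observed events: {y['event'].mean():.0%}")
```

Fixing the seed makes repeated calls return identical targets, which is convenient when comparing models on the same simulated data.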
Survival adapters
Survival adapters construct the survival distribution module used by the model. Available adapters include:
- `ExponentialSurvivalAdapter`
- `WeibullSurvivalAdapter`
- `LogNormalSurvivalAdapter`
- `LevySurvivalAdapter`
- `InverseGaussianSurvivalAdapter`
- `StepExpSurvivalAdapter`
- `ProportionalHazardSurvivalAdapter`
- `AcceleratedFailureTimeSurvivalAdapter`
- `MixtureSurvivalAdapter`
- `FractalNoiseSurvivalAdapter`
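The proportional hazards, accelerated failure time, and mixture adapters build on other distributions. The textbook relationships are: proportional hazards scales the hazard, so S(t | x) = S0(t) ** exp(eta); accelerated failure time (in one common sign convention) scales time, so S(t | x) = S0(t * exp(-eta)); a mixture is a weighted average of component survival functions. The NumPy sketch below shows these formulas with an exponential baseline; the function names and the baseline rate are assumptions, and survcraft's adapters may use other conventions.

```python
import numpy as np


def baseline_survival(t, rate=0.5):
    """Exponential baseline S0(t) = exp(-rate * t) (assumed for illustration)."""
    return np.exp(-rate * np.asarray(t, dtype=float))


def proportional_hazards(t, eta):
    # Hazard multiplied by exp(eta)  =>  S(t | x) = S0(t) ** exp(eta)
    return baseline_survival(t) ** np.exp(eta)


def accelerated_failure_time(t, eta):
    # Time stretched by exp(eta)  =>  S(t | x) = S0(t * exp(-eta))
    return baseline_survival(np.asarray(t, dtype=float) * np.exp(-eta))


def mixture(t, weights, components):
    # Convex combination of component survival functions (weights normalized to 1)
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()
    return sum(w * component(t) for w, component in zip(weights, components))


t = np.array([0.0, 1.0, 2.0])
# With an exponential baseline, PH with eta equals AFT with -eta
print(np.allclose(proportional_hazards(t, 0.3), accelerated_failure_time(t, -0.3)))  # True
s_mix = mixture(t, [1.0, 3.0], [baseline_survival, lambda s: proportional_hazards(s, 1.0)])
```

The exponential baseline is the special case where the two model families coincide; for other baselines (for example Weibull with shape other than 1 in the AFT direction, or log-normal) they generally differ.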
The resulting model can predict distribution quantities such as survival, failure, density, hazard, expected time, and median time when implemented by the selected survival module.
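As a concrete reference for these quantities, the sketch below computes them in closed form for a Weibull distribution with scale `scale` and shape `shape`. The parameterization S(t) = exp(-(t / scale) ** shape) is an assumption; survcraft's `WeibullSurvivalAdapter` may parameterize the distribution differently.

```python
import math

import numpy as np


def weibull_quantities(t, scale, shape):
    """Distribution quantities of a Weibull(scale, shape) model at times t > 0."""
    t = np.asarray(t, dtype=float)
    cum_hazard = (t / scale) ** shape                        # H(t)
    survival = np.exp(-cum_hazard)                           # S(t) = exp(-H(t))
    hazard = (shape / scale) * (t / scale) ** (shape - 1.0)  # h(t) = dH/dt
    return {
        "survival": survival,
        "failure": 1.0 - survival,                           # F(t) = 1 - S(t)
        "density": hazard * survival,                        # f(t) = h(t) S(t)
        "hazard": hazard,
        "expected_time": scale * math.gamma(1.0 + 1.0 / shape),
        "median_time": scale * math.log(2.0) ** (1.0 / shape),  # solves S(t) = 0.5
    }


q = weibull_quantities([0.5, 1.0, 2.0], scale=2.0, shape=1.5)
print(q["survival"], q["median_time"])
```

Expected and median times are scalars per parameter set, while survival, failure, density, and hazard are evaluated on the requested time grid.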
Loss modules
Training losses live in survcraft.loss_modules. The main public losses are:
- `FullLikelihoodLoss`
- `PartialLikelihoodLoss`
- `BrierLoss`
- `SquaredLoss`
- `BrierBatchTimesLoss`
- `BCEBatchTimesLoss`
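For intuition about what the two likelihood-based losses compute, here are textbook NumPy versions of a full (parametric) likelihood and the Cox partial likelihood for right-censored data. They are sketches, not survcraft's `FullLikelihoodLoss` or `PartialLikelihoodLoss`, which operate on PyTorch modules and may reduce or normalize differently; the function names, the Weibull model, and the Breslow treatment of tied times are assumptions chosen for the example.

```python
import numpy as np


def weibull_full_nll(event, time, scale, shape):
    """Mean negative full log-likelihood of right-censored data under a Weibull model.

    Events contribute log f(t) and censored samples log S(t). With hazard h and
    cumulative hazard H, log f = log h - H and log S = -H.
    """
    event = np.asarray(event, dtype=float)
    time = np.asarray(time, dtype=float)
    cum_hazard = (time / scale) ** shape
    log_hazard = np.log(shape / scale) + (shape - 1.0) * np.log(time / scale)
    return -np.mean(event * log_hazard - cum_hazard)


def cox_partial_nll(event, time, eta):
    """Negative Cox log partial likelihood per event (Breslow handling of ties)."""
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    eta = np.asarray(eta, dtype=float)
    # at_risk[i, j] is True when sample j is still at risk at time[i]
    at_risk = time[None, :] >= time[:, None]
    # Numerically stable log-sum-exp of eta over each risk set
    m = eta.max()
    log_risk = m + np.log((at_risk * np.exp(eta - m)[None, :]).sum(axis=1))
    return -np.sum((eta - log_risk)[event]) / max(int(event.sum()), 1)


print(weibull_full_nll([1, 0], [1.0, 2.0], scale=1.0, shape=1.0))
print(cox_partial_nll([True, True], [1.0, 2.0], [0.0, 0.0]))
```

The full likelihood needs a complete distribution for every sample, while the partial likelihood only compares risk scores `eta` within risk sets, which is why it leaves the baseline hazard unspecified.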
Loss modules can be combined arithmetically:
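```python
from survcraft import loss_modules as lm

# Full likelihood plus a Brier-score term weighted by 0.1
loss = lm.FullLikelihoodLoss() + 0.1 * lm.BrierLoss()
```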
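Arithmetic composition of this kind is typically implemented with Python operator overloading. The classes below (`Loss`, `SumLoss`, `ScaledLoss`, `ConstantLoss`) are hypothetical and are not survcraft's implementation; they only show how `+` and scalar `*` can build a tree of loss terms that is evaluated on call.

```python
class Loss:
    """Hypothetical base class whose instances support `+` and scalar `*`."""

    def __call__(self, *args):
        raise NotImplementedError

    def __add__(self, other):
        return SumLoss(self, other)

    def __mul__(self, factor):
        return ScaledLoss(self, factor)

    # Lets `0.1 * loss` work as well as `loss * 0.1`
    __rmul__ = __mul__


class SumLoss(Loss):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def __call__(self, *args):
        # Evaluate both terms on the same inputs and add them
        return self.left(*args) + self.right(*args)


class ScaledLoss(Loss):
    def __init__(self, loss, factor):
        self.loss, self.factor = loss, float(factor)

    def __call__(self, *args):
        return self.factor * self.loss(*args)


class ConstantLoss(Loss):
    """Hypothetical stand-in for a real loss term; returns a fixed value."""

    def __init__(self, value):
        self.value = value

    def __call__(self, *args):
        return self.value


combined = ConstantLoss(2.0) + 0.1 * ConstantLoss(5.0)
print(combined())  # 2.5
```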
Plotting
Plotting helpers require the optional plotting dependencies:
```bash
pip install "survcraft[plotting]"
```
```python
from survcraft import plotting

# `model` and `X` come from the Quick Start example
plotting.plot_outputs(model, X[:20])
plotting.plot_training_history(model)
```
Tutorials
Notebook tutorials are included in the repository under tutorials/:
- Basic usage.ipynb
- Implementing custom layers.ipynb
Install the tutorial dependencies with:
```bash
pip install "survcraft[tutorials]"
```
Some tutorials may also require optional survival-analysis dependencies:
```bash
pip install "survcraft[sksurv]"
```
Development
Clone the repository and install the package in editable mode:
```bash
git clone https://github.com/compbiomed-unito/survcraft.git
cd survcraft
pip install -e ".[test,plotting,sksurv]"
```
Run the test suite:
```bash
python -m pytest -q
```
Build the distribution artifacts:
```bash
python -m build
```
Before publishing to PyPI, validate the artifacts with:
```bash
python -m twine check dist/*
```
Project Status
Survcraft is currently an alpha-stage research library. Public APIs may change before a stable release.
License
Survcraft is distributed under the GNU Lesser General Public License v3.0. See `LICENSE` for the full license text.
File details

survcraft-0.1.0.tar.gz (source distribution)
- Size: 26.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.3

| Algorithm | Hash digest |
|---|---|
| SHA256 | 4ba211460e6a84202249053467d622ecd9541a6fe3bf290bbeb8d13e5feaac58 |
| MD5 | 8103b6075f456aba8892da520e2c9375 |
| BLAKE2b-256 | 2aba1023886db0635b9a815605ecc988136296998df2b22d7271bbddf7b24e15 |

survcraft-0.1.0-py3-none-any.whl (built distribution)
- Size: 27.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.3

| Algorithm | Hash digest |
|---|---|
| SHA256 | f77b5c45927f46c134686e66a0ed0781f5b6c11b530d548fe4933d650de1f745 |
| MD5 | a6b7b26b843362ecedcd90b133a1be05 |
| BLAKE2b-256 | 3542c9842b2ffa63008eb189f2d57cb681c6352d0fb1a7bdff11e88ad29b331b |