A discrete-time survival analysis model with a Kaplan-Meier inspired loss.
KMNet: Discrete-Time Survival Analysis with Deep Learning
⚡ State-of-the-art survival analysis with deep learning and ranking losses ⚡
Features • Installation • Quick Start • Demo Notebook • Citation
🎯 Use Cases & Applications
- 🏥 Clinical Research: patient survival analysis
- 🔧 Reliability Engineering: equipment failure prediction
- 📊 Business Analytics: customer churn prediction
KMNet combines the power of deep neural networks with Kaplan-Meier inspired ranking losses to deliver state-of-the-art performance in discrete-time survival analysis. Perfect for clinical research, reliability engineering, and time-to-event prediction.
It extends standard neural survival models by incorporating a novel Kaplan-Meier inspired rank loss, allowing the model to learn not just from local hazard rates but also from global ranking constraints inherent in survival data.
This library provides a high-performance implementation using PyTorch JIT to speed up custom loss calculations, making it suitable for large-scale survival datasets.
🚀 Key Features
- 🎯 Advanced Methodology: a Kaplan-Meier inspired ranking loss on top of standard discrete-time survival training
- ⚡ Performance: JIT-compiled loss functions for fast training on large datasets
📊 Why KMNet?
| Feature | Traditional Methods | DeepSurv/Cox-PH | KMNet |
|---|---|---|---|
| Handles Non-Linear Effects | ❌ | ✅ | ✅ |
| Captures Ranking Information | ⚠️ Partial | ❌ | ✅ |
| Discrete Time Bins | ❌ | ❌ | ✅ |
| GPU Acceleration | ❌ | ✅ | ✅ |
| Flexible Loss Functions | ❌ | ⚠️ Limited | ✅ |
| JIT-Optimized | N/A | ❌ | ✅ |
📦 Installation
You can install KMNet directly from the source:
```bash
git clone https://github.com/yuvrajiro/KMNet.git
cd KMNet
pip install .
```
Requirements
- Python >= 3.7
- PyTorch >= 1.7.0
- NumPy, Pandas
- torchtuples
- pycox
- numba
⚡ Quick Start
Here is a complete example of how to use KMNet on a synthetic dataset. You can also check out the interactive demo notebook.
1. Data Preparation
KMNet requires the target variable (time and event) to be discretized. We use the `label_transform` helper, built on pycox, for this.
```python
import numpy as np
import pandas as pd
from kmnet.model import KMNet

# Generate synthetic right-censored data
def make_data(n=1000):
    X = np.random.randn(n, 5).astype('float32')
    T = np.random.exponential(1 / (0.1 * np.exp(0.5 * X[:, 0])))  # event times
    C = np.random.exponential(1 / 0.05, size=n)                   # censoring times
    time = np.minimum(T, C)
    event = (T <= C).astype('float32')
    return X, time, event

X, time, event = make_data()

# Discretize time into 20 bins
num_durations = 20
labtrans = KMNet.label_transform(num_durations)
get_target = lambda df: (df['duration'].values, df['event'].values)
df = pd.DataFrame({'duration': time, 'event': event})
y = labtrans.fit_transform(*get_target(df))
```
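For intuition, here is a minimal numpy-only sketch of what a discrete-time label transform does conceptually: map each continuous duration onto a grid of cut points and return its bin index. The `discretize` helper below is hypothetical, not part of KMNet; the real pycox-based `label_transform` also handles censoring adjustments and different grid schemes.

```python
import numpy as np

def discretize(durations, num_durations=20):
    """Toy discretization: equidistant grid from 0 to the largest duration."""
    cuts = np.linspace(0.0, durations.max(), num_durations)
    # Index of the right-most cut point that does not exceed each duration
    idx = np.searchsorted(cuts, durations, side="right") - 1
    return cuts, idx.astype("int64")

durations = np.array([0.5, 2.0, 7.5, 10.0])
cuts, idx = discretize(durations, num_durations=5)
# cuts -> [0.0, 2.5, 5.0, 7.5, 10.0]; idx -> [0, 0, 3, 4]
```

Each subject's continuous time is thus replaced by an integer bin index, which is what the network's per-bin outputs are trained against.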
2. Define the Neural Network
You can use any PyTorch network architecture. The output dimension must match the number of time bins.
```python
import torch.nn as nn

in_features = X.shape[1]
out_features = labtrans.out_features  # one output per time bin

net = nn.Sequential(
    nn.Linear(in_features, 32),
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
    nn.Linear(32, out_features),
)
```
3. Train the Model
Initialize KMNet and fit it to the data.
```python
# Initialize the model with the discretized time grid
model = KMNet(net, duration_index=labtrans.cuts)

# Train
batch_size = 64
epochs = 10
model.fit(X, y, batch_size, epochs, verbose=True)
```
4. Prediction and Visualization
Predict survival functions and plot them.
```python
import matplotlib.pyplot as plt

# Predict survival probabilities for the first 5 samples
surv_df = model.predict_surv_df(X[:5])

# Plot
plt.figure(figsize=(10, 6))
for col in surv_df.columns:
    plt.step(surv_df.index, surv_df[col], where="post")
plt.ylabel("Survival Probability")
plt.xlabel("Time")
plt.title("Predicted Survival Curves")
plt.show()
```
🔬 Mathematical Background
KMNet models the discrete conditional survival probability $p(t \mid x)$, i.e. the probability of surviving time bin $t$ given survival up to that bin. The survival function is then the running product:

$$ \boxed{S(t \mid x) = \prod_{k=0}^{t} p(k \mid x)} $$
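In code, this product is a cumulative product over the per-bin conditional survival probabilities. A minimal numpy sketch for one subject (the probabilities below are invented for illustration, not model outputs):

```python
import numpy as np

# Per-bin conditional survival probabilities p(k | x) for one subject
# (made-up numbers for illustration).
p = np.array([0.99, 0.97, 0.95, 0.90, 0.85])

# S(t | x) = prod_{k=0}^{t} p(k | x): the running product gives the curve
S = np.cumprod(p)
```

Because each factor is at most 1, the resulting curve is monotonically non-increasing, as a survival function must be.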
The loss function $\mathcal{L}$ is a weighted sum of two components:
$$ \boxed{\mathcal{L} = \alpha \mathcal{L}_{bce} + (1 - \alpha)\lambda \mathcal{L}_{Rank}} $$
📖 Detailed explanation
Loss Components

- $\mathcal{L}_{bce}$ (Negative Log-Likelihood):
  - Standard survival loss for discrete time
  - Ensures the model fits the observed event times
  - Handles censored data correctly
- $\mathcal{L}_{Rank}$ (Rank Loss):
  - Enforces correct ordering: $S(T_i \mid x_i) < S(T_i \mid x_j)$ if $T_i < T_j$
  - Inspired by the Kaplan-Meier estimator
  - Maintains the global structure of survival curves
  - Differentiable approximation using softplus/exponential penalties
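As a rough illustration of the softplus variant of this penalty, here is a numpy sketch for a single pair of subjects. This is a sketch of the idea only; the function names and the `sigma` scaling are assumptions for illustration, not KMNet's actual implementation.

```python
import numpy as np

def softplus(z):
    """Smooth, differentiable approximation of max(0, z)."""
    return np.log1p(np.exp(z))

def pairwise_rank_penalty(S_i_at_Ti, S_j_at_Ti, sigma=0.1):
    # For a pair with T_i < T_j we want S(T_i | x_i) < S(T_i | x_j);
    # the argument is positive exactly when that ordering is violated.
    return softplus((S_i_at_Ti - S_j_at_Ti) / sigma)

ok = pairwise_rank_penalty(0.3, 0.8)   # correct ordering -> small penalty
bad = pairwise_rank_penalty(0.8, 0.3)  # violated ordering -> large penalty
```

Summing such terms over comparable pairs yields a differentiable surrogate for the ranking constraint, so it can be minimized jointly with the likelihood term by gradient descent.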
📊 Performance & Optimization
⚡ Speed Benchmarks

The KMNet class leverages TorchScript (JIT) to compile the custom loss functions, eliminating Python interpreter overhead. Benchmarks were run on an Intel i7 with 16 GB RAM.
🛠️ Advanced Configuration
You can customize the loss function parameters when initializing the model:
- `base`: `'nll'` (Negative Log-Likelihood) or `'bce'` (Binary Cross-Entropy)
- `rank_mode`: `'full'` (CDF-based) or `'conditional'` ranking
- `rank_penalty`: `'softplus'` or `'exp'` for the ranking loss formulation
```python
from kmnet.model import KMLoss

loss = KMLoss(
    alpha=0.5,               # weight for NLL vs. rank loss
    sigma=0.1,               # kernel width for soft ranking
    base='nll',              # 'nll' or 'bce'
    rank_mode='full',        # 'full' (CDF-based) or 'conditional'
    rank_penalty='softplus'  # 'softplus' or 'exp'
)
model = KMNet(net, loss=loss, duration_index=labtrans.cuts)
```
💬 Getting Help
- 📖 Documentation: Check out our examples and API docs
- 💡 Issues: Found a bug? Open an issue
- 💬 Discussions: Have questions? Start a discussion
- 📧 Email: For collaboration inquiries: your.email@example.com
🤝 Contributing
We welcome contributions! KMNet is an open-source research project.
How to contribute
- Fork the repository
- Create your feature branch (`git checkout -b feature/AmazingFeature`)
- Commit your changes (`git commit -m 'Add some AmazingFeature'`)
- Push to the branch (`git push origin feature/AmazingFeature`)
- Open a Pull Request
Areas we'd love help with:
- 📝 Documentation improvements
- 🧪 Additional benchmarks and examples
- 🐛 Bug fixes and testing
- ✨ New features (e.g., competing risks, time-varying covariates)
📄 License
Distributed under the MIT License. See LICENSE for more information.
📚 Citation (Under Consideration)
If you use KMNet in your research, please cite:
```bibtex
@article{YourLastName2025KMNet,
  title={KMNet: Optimizing Discrete-Time Survival Analysis with Rank-Based Loss},
  author={YourLastName, Yuvraj},
  journal={Journal Name (Under Review)},
  year={2025},
  url={https://github.com/yuvrajiro/KMNet}
}
```
🏆 Acknowledgments
KMNet builds upon excellent work from:
- PyCox for survival analysis utilities
- TorchTuples for training infrastructure
- The PyTorch team for JIT compilation capabilities
Built with ❤️ for the survival analysis community
⭐ Star us on GitHub if KMNet helps your research!