Deep Learning for Survival Analysis
TorchLife
Survival Analysis using PyTorch
This library takes a deep learning approach to survival analysis.
Install
pip install torchlife
How to use
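The model expects a pandas DataFrame with a column named 't' holding the observed time and a column named 'e' holding the event indicator (1 if the event, such as death, was observed, 0 if the observation was censored). The remaining columns act as covariates. In the Rossi recidivism data used below, the 'week' and 'arrest' columns play these roles, so they are renamed to 't' and 'e'.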
import pandas as pd
import numpy as np
url = "https://raw.githubusercontent.com/CamDavidsonPilon/lifelines/master/lifelines/datasets/rossi.csv"
df = pd.read_csv(url)
df.rename(columns={'week':'t', 'arrest':'e'}, inplace=True)
df.head()
 | t | e | fin | age | race | wexp | mar | paro | prio |
---|---|---|---|---|---|---|---|---|---|
0 | 20 | 1 | 0 | 27 | 1 | 0 | 0 | 1 | 3 |
1 | 17 | 1 | 0 | 18 | 1 | 0 | 0 | 1 | 8 |
2 | 25 | 1 | 0 | 19 | 0 | 1 | 0 | 1 | 13 |
3 | 52 | 0 | 1 | 23 | 1 | 1 | 1 | 1 | 1 |
4 | 52 | 0 | 0 | 19 | 0 | 1 | 0 | 1 | 3 |
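Before fitting, it can help to confirm that the frame has the layout described above. The sketch below rebuilds the five rows shown in the table by hand (so it runs offline) and checks them with a hypothetical helper, `check_survival_frame`, which is not part of torchlife. Rows with `e == 0` are censored: no arrest was recorded during the 52-week follow-up, so their `t` of 52 marks the end of observation rather than an event time.

```python
import pandas as pd

# The first five rows of the renamed Rossi data, copied from the table above.
df_head = pd.DataFrame({
    "t":    [20, 17, 25, 52, 52],
    "e":    [1, 1, 1, 0, 0],
    "fin":  [0, 0, 0, 1, 0],
    "age":  [27, 18, 19, 23, 19],
    "race": [1, 1, 0, 1, 0],
    "wexp": [0, 0, 1, 1, 1],
    "mar":  [0, 0, 0, 1, 0],
    "paro": [1, 1, 1, 1, 1],
    "prio": [3, 8, 13, 1, 3],
})

def check_survival_frame(frame: pd.DataFrame) -> None:
    """Hypothetical helper: check the 't' and 'e' columns the model needs."""
    missing = {"t", "e"} - set(frame.columns)
    if missing:
        raise ValueError(f"missing required columns: {sorted(missing)}")
    if (frame["t"] < 0).any():
        raise ValueError("'t' must be non-negative")
    if not frame["e"].isin([0, 1]).all():
        raise ValueError("'e' must be 0 (censored) or 1 (event)")

check_survival_frame(df_head)
n_events = int(df_head["e"].sum())           # rows with an observed arrest
n_censored = int((df_head["e"] == 0).sum())  # rows censored at the end of follow-up
```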
from torchlife.model import ModelHazard
model = ModelHazard('cox', lr=0.5)
model.fit(df)
λ, S = model.predict(df)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 6.993955 | 10.741218 | 00:00 |
1 | 8.774823 | 14.736155 | 00:00 |
2 | 9.991431 | 16.564432 | 00:00 |
3 | 10.995527 | 17.174604 | 00:00 |
4 | 11.723181 | 16.920387 | 00:00 |
5 | 12.060142 | 15.983603 | 00:00 |
6 | 12.174074 | 14.553919 | 00:00 |
7 | 12.038597 | 12.683950 | 00:00 |
8 | 11.702325 | 10.452137 | 00:00 |
9 | 11.218502 | 7.981377 | 00:00 |
10 | 10.570101 | 5.209520 | 00:00 |
11 | 9.859859 | 4.039678 | 00:00 |
12 | 9.155064 | 3.643379 | 00:00 |
13 | 8.514476 | 2.742133 | 00:00 |
14 | 7.915660 | 3.074418 | 00:00 |
15 | 7.413548 | 2.585245 | 00:00 |
16 | 6.967895 | 2.710384 | 00:00 |
17 | 6.569957 | 2.544009 | 00:00 |
18 | 6.215098 | 2.433515 | 00:00 |
19 | 5.880322 | 2.342750 | 00:00 |
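The `'cox'` argument appears to select a Cox proportional-hazards model, λ(t | x) = λ₀(t)·exp(η(x)), where η(x) is a learned risk score. The textbook objective for such a model is the negative log partial likelihood, which depends only on the risk scores and the ordering of the observed times. The NumPy sketch below shows that standard formulation with Breslow tie handling, averaged over events; whether torchlife's reported `train_loss` and `valid_loss` use exactly this form, normalization, or tie handling is an assumption, and `cox_neg_log_partial_likelihood` is a hypothetical name.

```python
import numpy as np

def cox_neg_log_partial_likelihood(eta, t, e):
    """Average negative log partial likelihood of a Cox model.

    eta: risk scores (log relative hazards), shape (n,)
    t:   observed times, shape (n,)
    e:   event indicators, 1 = event and 0 = censored, shape (n,)
    Breslow ties: every subject with t_j >= t_i is in subject i's risk set.
    """
    eta = np.asarray(eta, dtype=float)
    t = np.asarray(t, dtype=float)
    e = np.asarray(e, dtype=bool)
    # at_risk[i, j] is True when subject j is still at risk at time t_i.
    at_risk = t[None, :] >= t[:, None]
    # Log-sum-exp over each risk set, shifted by the max for numerical stability.
    m = eta.max()
    log_risk = m + np.log((np.exp(eta - m)[None, :] * at_risk).sum(axis=1))
    # Only subjects with an observed event contribute a term.
    return -np.sum((eta - log_risk)[e]) / max(e.sum(), 1)
```

Adding the same constant to every risk score leaves this loss unchanged, which is why a Cox model needs no intercept: the baseline hazard λ₀(t) absorbs it.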
Let's plot the survival function for the third row of the dataframe (positional index 2):
x = df.drop(['t', 'e'], axis=1).iloc[2]
t = np.arange(df['t'].max())
model.plot_survival_function(t, x)
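`model.predict` returns both a hazard `λ` and a survival function `S`. These are linked by the standard identity S(t) = exp(-Λ(t)), where the cumulative hazard Λ(t) is the integral of λ from 0 to t. The sketch below, built around a hypothetical helper `survival_from_hazard`, approximates that integral with a left Riemann sum on a time grid like the `t` used for plotting. It assumes one hazard value per grid point; torchlife's own output shapes and integration scheme may differ.

```python
import numpy as np

def survival_from_hazard(hazard, t):
    """Approximate S(t) = exp(-integral_0^t hazard(u) du) on a time grid.

    hazard: hazard values evaluated at the grid points, shape (K,)
    t:      increasing grid of times, starting at 0, shape (K,)
    A left Riemann sum is used, so S equals 1 at the first grid point.
    """
    hazard = np.asarray(hazard, dtype=float)
    t = np.asarray(t, dtype=float)
    dt = np.diff(t)
    # Cumulative hazard: 0 at t[0], then add hazard[k-1] * (t[k] - t[k-1]).
    cum_hazard = np.concatenate([[0.0], np.cumsum(hazard[:-1] * dt)])
    return np.exp(-cum_hazard)

# A constant hazard of 0.01 per week on a weekly grid gives S(t) = exp(-0.01 t).
grid = np.arange(52)
S_const = survival_from_hazard(np.full(grid.shape, 0.01), grid)
```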
Download files
Source Distribution
torchlife-0.0.2.tar.gz (10.9 kB)
Built Distribution
torchlife-0.0.2-py3-none-any.whl (17.2 kB)
File details
Details for the file torchlife-0.0.2.tar.gz.
File metadata
- Download URL: torchlife-0.0.2.tar.gz
- Upload date:
- Size: 10.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.2.0.post20200210 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | bc40d37bf6a443ae280537b6a77c99f29be912e79805415c9025b7ef0d5d93b5 |
MD5 | 80fe12dbb149912f392de8258172c5b5 |
BLAKE2b-256 | 1f8852905c6e23bbb895efbb49451e9143149003c4bbbb1e2f5b3ef16a181310 |
File details
Details for the file torchlife-0.0.2-py3-none-any.whl.
File metadata
- Download URL: torchlife-0.0.2-py3-none-any.whl
- Upload date:
- Size: 17.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.2.0.post20200210 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 64bbb1187baa6ee078d0a3351afaba121d1976a55f1d72c6a9f97873bc3a2a28 |
MD5 | a2c793b08d58ecb9136e416cb2c83c0b |
BLAKE2b-256 | 58b1088f6cf26b14c72538509111f7ded823b4e1263f8cd624e02f55e43e3bbc |
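The SHA256 digests listed for the two torchlife 0.0.2 files let you check a download before installing it. The sketch below uses Python's standard `hashlib` and reads the file in chunks so memory use stays bounded. The helper names (`sha256_of_chunks`, `sha256_of_file`, `verify_download`) are hypothetical, and the local path is wherever you saved the download.

```python
import hashlib

# SHA256 digests published above for torchlife 0.0.2.
EXPECTED_SHA256 = {
    "torchlife-0.0.2.tar.gz": "bc40d37bf6a443ae280537b6a77c99f29be912e79805415c9025b7ef0d5d93b5",
    "torchlife-0.0.2-py3-none-any.whl": "64bbb1187baa6ee078d0a3351afaba121d1976a55f1d72c6a9f97873bc3a2a28",
}

def sha256_of_chunks(chunks):
    """Hex SHA256 digest of an iterable of byte strings, hashed incrementally."""
    digest = hashlib.sha256()
    for chunk in chunks:
        digest.update(chunk)
    return digest.hexdigest()

def sha256_of_file(path, chunk_size=1 << 16):
    """Hex SHA256 digest of a file, read in fixed-size chunks."""
    with open(path, "rb") as fh:
        return sha256_of_chunks(iter(lambda: fh.read(chunk_size), b""))

def verify_download(path, filename):
    """True if the file at `path` matches the published digest for `filename`."""
    return sha256_of_file(path) == EXPECTED_SHA256[filename]
```

For example, `verify_download("downloads/torchlife-0.0.2-py3-none-any.whl", "torchlife-0.0.2-py3-none-any.whl")` should return `True` for an intact wheel. pip can enforce the same check through its hash-checking mode (`--hash=sha256:...` entries in a requirements file, with `--require-hashes`).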