
Deep Learning for Survival Analysis


TorchLife

Survival analysis using PyTorch

This library takes a deep learning approach to survival analysis.

Install

pip install torchlife

How to use

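We need a dataframe with a column named 't' holding the time until the event (or until censoring), and a column named 'e' that is 1 if the event (e.g. death) was observed and 0 if the observation was censored.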
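Before fitting, it can help to check that a dataframe actually has this shape. The helper below is hypothetical (`check_survival_columns` is not part of torchlife); it is a minimal sketch that only enforces the two conventions stated above: a non-negative 't' and a 0/1 'e'.

```python
import pandas as pd

def check_survival_columns(df: pd.DataFrame) -> None:
    """Hypothetical helper: raise ValueError if df lacks a usable 't'/'e' pair."""
    missing = {"t", "e"} - set(df.columns)
    if missing:
        raise ValueError(f"missing required columns: {sorted(missing)}")
    if (df["t"] < 0).any():
        raise ValueError("'t' must be non-negative")
    if not df["e"].isin([0, 1]).all():
        raise ValueError("'e' must be 0 (censored) or 1 (event observed)")

# Two subjects: one with an observed event at t=3, one censored at t=5.
toy = pd.DataFrame({"t": [3.0, 5.0], "e": [1, 0], "age": [40, 52]})
check_survival_columns(toy)  # passes silently
```

Any other columns (here `age`) are left alone; in the example that follows they serve as covariates.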

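import pandas as pd
import numpy as np

# Rossi recidivism data from the lifelines repository:
# 'week' = week of first arrest (52 if not arrested during follow-up), 'arrest' = 1 if arrested
url = "https://raw.githubusercontent.com/CamDavidsonPilon/lifelines/master/lifelines/datasets/rossi.csv"
df = pd.read_csv(url)
# Rename the duration and event columns to the names torchlife expects
df.rename(columns={'week':'t', 'arrest':'e'}, inplace=True)
df.head()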
    t  e  fin  age  race  wexp  mar  paro  prio
0  20  1    0   27     1     0    0     1     3
1  17  1    0   18     1     0    0     1     8
2  25  1    0   19     0     1    0     1    13
3  52  0    1   23     1     1    1     1     1
4  52  0    0   19     0     1    0     1     3
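As a non-parametric point of comparison for the model's survival curves, the Kaplan-Meier estimate can be computed directly from the 't' and 'e' columns. This is the standard textbook estimator, not part of torchlife, and the function name `kaplan_meier` is our own; a minimal NumPy sketch:

```python
import numpy as np

def kaplan_meier(t, e):
    """Kaplan-Meier estimate of S(t) at each distinct observed event time.

    t: durations; e: 1 if the event was observed, 0 if censored.
    Returns (event_times, survival_probabilities).
    """
    t = np.asarray(t, dtype=float)
    e = np.asarray(e, dtype=int)
    event_times = np.unique(t[e == 1])
    surv = []
    s = 1.0
    for u in event_times:
        at_risk = np.sum(t >= u)               # subjects still under observation at u
        deaths = np.sum((t == u) & (e == 1))   # events occurring exactly at u
        s *= 1.0 - deaths / at_risk            # product-limit update
        surv.append(s)
    return event_times, np.array(surv)
```

On the Rossi data this would be called as `times, surv = kaplan_meier(df['t'], df['e'])`. Censored subjects contribute to the risk set up to their censoring time but never count as events, which is what distinguishes this from a naive fraction of survivors.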
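from torchlife.model import ModelHazard

# Hazard model of type 'cox', trained with learning rate 0.5
model = ModelHazard('cox', lr=0.5)
model.fit(df)              # prints the per-epoch training log below
λ, S = model.predict(df)   # presumably the predicted hazard (λ) and survival (S)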
epoch  train_loss  valid_loss  time
0      6.993955    10.741218   00:00
1      8.774823    14.736155   00:00
2      9.991431    16.564432   00:00
3      10.995527   17.174604   00:00
4      11.723181   16.920387   00:00
5      12.060142   15.983603   00:00
6      12.174074   14.553919   00:00
7      12.038597   12.683950   00:00
8      11.702325   10.452137   00:00
9      11.218502   7.981377    00:00
10     10.570101   5.209520    00:00
11     9.859859    4.039678    00:00
12     9.155064    3.643379    00:00
13     8.514476    2.742133    00:00
14     7.915660    3.074418    00:00
15     7.413548    2.585245    00:00
16     6.967895    2.710384    00:00
17     6.569957    2.544009    00:00
18     6.215098    2.433515    00:00
19     5.880322    2.342750    00:00
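The log reports `train_loss` and `valid_loss` without naming the loss. Assuming `'cox'` refers to the Cox proportional hazards model, λ(t | x) = λ0(t)·exp(η(x)) with a learned log-risk score η(x), a common training objective is the negative log partial likelihood. The NumPy sketch below is our own illustration, not torchlife's code; it assumes no tied times and averages over observed events.

```python
import numpy as np

def cox_neg_log_partial_likelihood(risk, t, e):
    """Mean negative log Cox partial likelihood (assumes no tied times).

    risk: log-risk scores eta_i for each subject, shape (n,)
    t: durations; e: event indicators (1 = observed, 0 = censored)
    """
    risk = np.asarray(risk, dtype=float)
    t = np.asarray(t, dtype=float)
    e = np.asarray(e, dtype=bool)
    order = np.argsort(-t)          # sort subjects by descending time
    risk, e = risk[order], e[order]
    # After sorting, the risk set of subject i (all j with t_j >= t_i) is the
    # prefix ending at i, so a running log-sum-exp gives log sum_{j in R_i} exp(eta_j).
    log_risk_set = np.logaddexp.accumulate(risk)
    n_events = max(e.sum(), 1)      # avoid division by zero if everything is censored
    # Only subjects with an observed event contribute a term.
    return -np.sum((risk - log_risk_set)[e]) / n_events
```

Only the ordering of the scores within each risk set matters here, which is why the baseline hazard λ0(t) drops out of this objective.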

Let's plot the predicted survival function for the third row of the dataframe (positional index 2), using only its covariates:

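# Covariates of the third row (positional index 2); 't' and 'e' are targets, not features
x = df.drop(['t', 'e'], axis=1).iloc[2]
# Time grid from 0 up to and including the longest observed time (52 weeks)
t = np.arange(df['t'].max() + 1)
model.plot_survival_function(t, x)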

[Figure: predicted survival function for the selected row]
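For any hazard model, the survival curve is tied to the hazard by S(t) = exp(-∫₀ᵗ λ(u) du). As an illustration (our own, not torchlife code; the name `survival_from_hazard` is hypothetical), the sketch below approximates that integral with a left Riemann sum on an evenly spaced grid such as `t` above, so that S equals 1 at the first grid point.

```python
import numpy as np

def survival_from_hazard(hazard, dt):
    """Approximate S(t_k) = exp(-sum_{j<k} hazard_j * dt) on an even grid."""
    hazard = np.asarray(hazard, dtype=float)
    # Cumulative hazard at each grid point; zero at the first point.
    cum = np.concatenate([[0.0], np.cumsum(hazard[:-1] * dt)])
    return np.exp(-cum)

# Constant hazard of 0.1 per week on a 0..52 week grid gives S(t) = exp(-0.1 t).
S = survival_from_hazard(np.full(53, 0.1), dt=1.0)
```

A constant hazard recovers the exponential survival curve exactly; a hazard that varies over time bends the curve accordingly.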



Download files


Source Distribution

torchlife-0.0.2.tar.gz (10.9 kB)

Built Distribution

torchlife-0.0.2-py3-none-any.whl (17.2 kB)
