Module for sepsis predictions
Project description
Disease-predictions module
Introduction

sickness-screening is a module built on pandas, torch, and scikit-learn that lets users perform simple operations on the MIMIC dataset. With just a few functions, you can train a model to predict whether patients have certain diseases. By default, the module is configured to train on and predict sepsis. It also lets users change the names of the tables from which data is aggregated.
Installation

To install the module, use the following command:

```shell
pip install sickness-screening
```

or

```shell
pip3 install sickness-screening
```
Usage
You can import functions from the module into your Python file to aggregate data from MIMIC, fill in missing values, compress data per patient, and train your model.
Examples
Aggregate patient diagnoses data

```python
import sickness_screening as ss

ss.get_diagnoses_data(patient_diagnoses_csv='path_to_patient_diagnoses.csv',
                      all_diagnoses_csv='path_to_all_diagnoses.csv',
                      output_file_csv='gottenDiagnoses.csv')
```
Aggregate patient SSIR data

```python
import sickness_screening as ss

ss.get_analasys_data(chartevents_csv='chartevents.csv', subject_id_col='subject_id', itemid_col='itemid',
                     charttime_col='charttime', value_col='value', valuenum_col='valuenum', valueuom_col='valueuom',
                     itemids=None, rest_columns=None, output_csv='ssir.csv')
```
Combine diagnoses and SSIR data

```python
import sickness_screening as ss

ss.combine_data(gotten_diagnoses_csv='gottenDiagnoses.csv',
                ssir_csv='path_to_ssir.csv',
                output_file='diagnoses_and_ssir.csv')
```
Aggregate patient blood analysis data from chartevents.csv and labevents.csv and combine it with the diagnoses and SSIR data

```python
import sickness_screening as ss

ss.merge_and_get_data(diagnoses_and_ssir_csv='diagnoses_and_ssir.csv',
                      blood_csv='path_to_blood.csv',
                      chartevents_csv='path_to_chartevents.csv',
                      output_csv='merged_data.csv')
```
Compress data by patient

```python
import sickness_screening as ss

ss.compress(df_to_compress='balanced_data.csv',
            output_csv='compressed_data.csv')
```
Choose top non-sepsis patients to balance the classes

```python
import sickness_screening as ss

ss.choose(compressed_df_csv='compressed_data.csv',
          output_file='final_balanced_data.csv')
```
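Conceptually, this balancing step is majority-class undersampling: every sepsis patient is kept, and only as many non-sepsis patients are retained. A toy pandas sketch of the idea (not the library's actual implementation; the `has_sepsis` column name follows the training example in this README):

```python
import pandas as pd

# Toy data: five non-sepsis patients, two sepsis patients
df = pd.DataFrame({
    'subject_id': range(7),
    'has_sepsis': [0, 0, 0, 0, 0, 1, 1],
})

# Keep every sepsis row and only as many non-sepsis rows
n_pos = int((df['has_sepsis'] == 1).sum())
balanced = pd.concat([
    df[df['has_sepsis'] == 1],
    df[df['has_sepsis'] == 0].head(n_pos),
])
```

After this step both classes contribute the same number of patients, which keeps a classifier from simply predicting the majority class.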
Fill missing values with the mode

```python
import sickness_screening as ss

ss.fill_values(balanced_csv='final_balanced_data.csv',
               strategy='most_frequent',
               output_csv='filled_data.csv')
```
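A `'most_frequent'` strategy conceptually matches scikit-learn's `SimpleImputer`: each missing value in a column is replaced with that column's mode. A standalone sketch:

```python
import numpy as np
from sklearn.impute import SimpleImputer

X = np.array([[1.0, np.nan],
              [1.0, 3.0],
              [2.0, 3.0],
              [np.nan, 4.0]])

# Replace each NaN with its column's most frequent value
imputer = SimpleImputer(strategy='most_frequent')
X_filled = imputer.fit_transform(X)
# column 0 -> 1.0, column 1 -> 3.0
```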
Train model

```python
import sickness_screening as ss
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import MinMaxScaler

model = ss.train_model(df_to_train_csv='filled_data.csv',
                       categorical_col=['Large Platelets'],
                       columns_to_train_on=['Amylase'],
                       model=RandomForestClassifier(),
                       single_cat_column='White Blood Cells',
                       has_disease_col='has_sepsis',
                       subject_id_col='subject_id',
                       valueuom_col='valueuom',
                       scaler=MinMaxScaler(),
                       random_state=42,
                       test_size=0.2)
```
For example, you can plug in models such as CatBoostClassifier or SVC with different kernels.

CatBoostClassifier:

```python
from catboost import CatBoostClassifier

# Weight the positive (sepsis) class more heavily to counter class imbalance
class_weights = {0: 1, 1: 15}
clf = CatBoostClassifier(loss_function='MultiClassOneVsAll', class_weights=class_weights,
                         iterations=50, learning_rate=0.1, depth=5)
clf.fit(X_train, y_train)
```
SVC with a Gaussian radial basis function (RBF) kernel:

```python
from scipy.stats import reciprocal
from sklearn.metrics import make_scorer, recall_score
from sklearn.model_selection import RandomizedSearchCV
from sklearn.svm import SVC

class_weights = {0: 1, 1: 13}
param_dist = {
    'C': reciprocal(0.1, 100),
    'gamma': reciprocal(0.01, 10),
    'kernel': ['rbf']
}
svm_model = SVC(class_weight=class_weights, random_state=42)
random_search = RandomizedSearchCV(
    svm_model,
    param_distributions=param_dist,
    n_iter=10,
    cv=5,
    scoring=make_scorer(recall_score, pos_label=1),
    n_jobs=-1
)
random_search.fit(X_train, y_train)
```
The second method (TabNet and DeepFM)

Collecting features into a dataset

You can choose any features, but we will take the four used in MEWS (Modified Early Warning Score) to predict sepsis in the first hours of a patient's hospital stay:
- Systolic blood pressure
- Heart rate
- Respiratory rate
- Temperature
```python
import pandas as pd
from tqdm import tqdm

# itemids for systolic blood pressure, heart rate, respiratory rate, and temperature
item_ids = [str(x) for x in [225309, 220045, 220210, 223762]]
item_ids_set = set(item_ids)

result = {}
with open(file_path) as f:  # file_path points to MIMIC's chartevents.csv
    headers = f.readline().replace('\n', '').split(',')  # skip the header row
    for line in tqdm(f):
        values = line.replace('\n', '').split(',')
        subject_id = values[0]
        item_id = values[6]
        valuenum = values[8]
        if item_id in item_ids_set:
            if subject_id not in result:
                result[subject_id] = {}
            result[subject_id][item_id] = valuenum

table = pd.DataFrame.from_dict(result, orient='index')
table['subject_id'] = table.index
```
Adding the target

```python
# Patients whose DRG code corresponds to sepsis (870-872) are labeled positive
target_subjects = drgcodes.loc[drgcodes['drg_code'].isin([870, 871, 872]), 'subject_id']
merged_data.loc[merged_data['subject_id'].isin(target_subjects), 'diagnosis'] = 1
```
Filling in gaps using the NoNa library. This algorithm fills in gaps using various machine learning methods; here we use StandardScaler, Ridge, and RandomForestClassifier:

```python
from nona.nona import nona  # NoNa's fill-in function (import path may vary by version)
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

nona(
    data=X,
    algreg=make_pipeline(StandardScaler(with_mean=False), Ridge(alpha=0.1)),
    algclass=RandomForestClassifier(max_depth=2, random_state=0)
)
```
Addressing class imbalance using SMOTE

```python
from imblearn.over_sampling import SMOTE

smote = SMOTE(random_state=random_state)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)
```
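SMOTE creates synthetic minority samples by interpolating between a minority point and one of its nearest minority-class neighbors. A self-contained sketch of that interpolation step (an illustration of the idea, not the imblearn implementation):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
minority = rng.normal(loc=5.0, scale=1.0, size=(10, 2))  # toy minority-class points

# Find each point's nearest minority neighbor (k=2: the first neighbor is the point itself)
nn = NearestNeighbors(n_neighbors=2).fit(minority)
_, idx = nn.kneighbors(minority)

# Synthesize one new point per original: step a random fraction toward the neighbor
neighbors = minority[idx[:, 1]]
gap = rng.random((10, 1))
synthetic = minority + gap * (neighbors - minority)
```

Each synthetic point lies on the segment between an original minority point and its neighbor, so the oversampled class stays inside its original region of feature space.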
Training the TabNet model. TabNet is a PyTorch-based model: first we run unsupervised pretraining with TabNetPretrainer, then create and train a classification model using TabNetClassifier:

```python
import torch
from pytorch_tabnet.pretraining import TabNetPretrainer
from pytorch_tabnet.tab_model import TabNetClassifier

unsupervised_model = TabNetPretrainer(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=pretraining_lr),
    mask_type=mask_type
)
unsupervised_model.fit(
    X_train=X_train.values,
    eval_set=[X_val.values],
    pretraining_ratio=pretraining_ratio,
)

clf = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=training_lr),
    scheduler_params=scheduler_params,
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    mask_type=mask_type
)
clf.fit(
    X_train=X_train.values, y_train=y_train.values,
    eval_set=[(X_val.values, y_val.values)],
    eval_metric=['auc'],
    max_epochs=max_epochs,
    patience=patience,
    from_unsupervised=unsupervised_model
)
```
Training the DeepFM model

```python
from libreco.algorithms import DeepFM  # DeepFM from the LibRecommender library

deepfm = DeepFM("ranking", data_info, embed_size=16, n_epochs=2,
                lr=1e-4, lr_decay=False, reg=None, batch_size=1,
                num_neg=1, use_bn=False, dropout_rate=None,
                hidden_units="128,64,32", tf_sess_config=None)
deepfm.fit(train_data, verbose=2, shuffle=True, eval_data=eval_data,
           metrics=["loss", "balanced_accuracy", "roc_auc", "pr_auc",
                    "precision", "recall", "map", "ndcg"])
```
Viewing the obtained metrics

```python
from sklearn.metrics import f1_score, precision_score, recall_score

result = loaded_clf.predict(X_test.values)
accuracy = (result == y_test.values).mean()
precision = precision_score(y_test.values, result)
recall = recall_score(y_test.values, result)
f1 = f1_score(y_test.values, result)
```
Visualization of the first two PCA components was performed. The loadings on each component are presented below:

| Feature | Loading on the first component | Loading on the second component |
|---|---|---|
| Heart rate | -0.101450 | 0.991611 |
| Temperature | 0.001178 | 0.013098 |
| Systolic BP | 0.994771 | 0.100169 |
| Respiratory rate | 0.011673 | 0.080573 |
| MEWS | -0.000660 | 0.003313 |
No patterns were found.
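Loadings like those in the table can be computed with scikit-learn's PCA: `components_` holds the weight of each original feature in each principal component. A sketch on synthetic data (the table above was computed on the real patient features):

```python
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA

rng = np.random.default_rng(42)
features = ['Heart rate', 'Temperature', 'Systolic BP', 'Respiratory rate', 'MEWS']
X = pd.DataFrame(rng.normal(size=(200, 5)), columns=features)

# Project onto the first two principal components
pca = PCA(n_components=2)
coords = pca.fit_transform(X)

# Loadings: one row per feature, one column per component
loadings = pd.DataFrame(pca.components_.T, index=features,
                        columns=['PC1 loading', 'PC2 loading'])
```

Plotting `coords` colored by the sepsis label is what produced the visualization described here.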
A variational autoencoder was trained to build a separable 2D latent space. We can see that the classes overlap and are not separable.