A Python library for healthcare AI


PyHealth is designed for both ML researchers and medical practitioners. It makes your healthcare AI applications easier to deploy, and more flexible and customizable.


Introduction

PyHealth supports diverse electronic health records (EHRs), such as MIMIC, eICU, and any OMOP-CDM-based database, and provides advanced deep learning algorithms for important healthcare tasks such as diagnosis-based drug recommendation, patient hospitalization and mortality prediction, and ICU length-of-stay forecasting.

Building a healthcare AI pipeline can take as few as 10 lines of code in PyHealth.

Modules

All healthcare tasks in our package follow a five-stage pipeline:

load dataset -> define task function -> build ML/DL model -> model training -> inference

We try hard to keep each stage as decoupled as possible, so that you can customize your own pipeline using only our data-processing steps or only our ML models. Each stage calls one module; we introduce them below with an example.

An ML Pipeline Example

  • STEP 1: <pyhealth.datasets> provides a clean structure for the dataset, independent of the tasks. We support MIMIC-III, MIMIC-IV, and eICU, as well as standard OMOP-formatted data. The dataset is stored in a unified Patient-Visit-Event structure.

from pyhealth.datasets import MIMIC3Dataset
mimic3dataset = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/mimiciii/1.4/",
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    code_mapping={"NDC": "ATC"}, # map all NDC codes to ATC codes in these tables
)
  • STEP 2: <pyhealth.tasks> takes the <pyhealth.datasets> object and defines how to process each patient's data into a set of samples for the task. The package provides several task examples, such as drug recommendation and length-of-stay prediction.

from pyhealth.tasks import drug_recommendation_mimic3_fn
from pyhealth.datasets.splitter import split_by_patient
from torch.utils.data import DataLoader
from pyhealth.utils import collate_fn_dict

mimic3dataset.set_task(task_fn=drug_recommendation_mimic3_fn) # use default drugrec task
train_ds, val_ds, test_ds = split_by_patient(mimic3dataset, [0.8, 0.1, 0.1])

# create dataloaders
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=collate_fn_dict)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, collate_fn=collate_fn_dict)
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, collate_fn=collate_fn_dict)
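
The collate function turns a list of per-sample dictionaries into a single batch dictionary. As a rough illustration of the idea (a hypothetical re-implementation, not PyHealth's actual collate_fn_dict), such a function merges values under each key into lists:

```python
# Hypothetical sketch of a dict-based collate function: it converts a
# list of per-sample dicts into one dict of lists, the batch format
# that dict-consuming models expect.
def collate_dicts(samples):
    batch = {}
    for sample in samples:
        for key, value in sample.items():
            batch.setdefault(key, []).append(value)
    return batch

batch = collate_dicts([
    {"conditions": ["401.9"], "label": 1},
    {"conditions": ["428.0", "250.00"], "label": 0},
])
print(batch["label"])  # [1, 0]
```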
  • STEP 3: <pyhealth.models> provides healthcare ML models that operate on <pyhealth.datasets>. This module also provides model layers, such as pyhealth.models.RETAINLayer, for building customized ML architectures. Our model layers can be used as easily as torch.nn.Linear.

from pyhealth.models import Transformer

device = "cuda:0"
model = Transformer(
    dataset=mimic3dataset,
    tables=["conditions", "procedures"],
    mode="multilabel",
)
model.to(device)
  • STEP 4: <pyhealth.trainer> is the training manager: pass in the train_loader, the val_loader, and a val_metric, and specify other arguments such as the number of epochs, the optimizer, and the learning rate. The trainer automatically saves the best model and outputs its path at the end.

from pyhealth.trainer import Trainer
from pyhealth.metrics import pr_auc_multilabel
import torch

trainer = Trainer(enable_logging=True, output_path="../output", device=device)
trainer.fit(model,
    train_loader=train_loader,
    epochs=10,
    optimizer_class=torch.optim.Adam,
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
    val_loader=val_loader,
    val_metric=pr_auc_multilabel,
)
# Best model saved to: ../output/221004-015401/best.ckpt
  • STEP 5: <pyhealth.metrics> provides: (i) common evaluation metrics, used the same way as sklearn.metrics; (ii) metrics (weighted by patients) for patient-level tasks; (iii) special metrics in healthcare, such as the drug-drug interaction (DDI) rate.

from pyhealth.evaluator import evaluate
from pyhealth.metrics import accuracy_multilabel, jaccard_multilabel, f1_multilabel

# load best model and do inference
model = trainer.load_best_model(model)
y_gt, y_prob, y_pred = evaluate(model, test_loader, device)

jaccard = jaccard_multilabel(y_gt, y_pred)
accuracy = accuracy_multilabel(y_gt, y_pred)
f1 = f1_multilabel(y_gt, y_pred)
prauc = pr_auc_multilabel(y_gt, y_prob)

print("jaccard: ", jaccard)
print("accuracy: ", accuracy)
print("f1: ", f1)
print("prauc: ", prauc)
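
For intuition about what the multilabel metrics compute, here is a self-contained sketch (not PyHealth's implementation) of sample-averaged Jaccard, i.e., the per-patient intersection-over-union of the predicted and true label sets, averaged over patients:

```python
def jaccard_samples(y_true, y_pred):
    """Mean over samples of |true ∩ pred| / |true ∪ pred| on positive labels."""
    scores = []
    for t, p in zip(y_true, y_pred):
        t_pos = {i for i, v in enumerate(t) if v}
        p_pos = {i for i, v in enumerate(p) if v}
        union = t_pos | p_pos
        # Convention: score 1.0 when both label sets are empty.
        scores.append(len(t_pos & p_pos) / len(union) if union else 1.0)
    return sum(scores) / len(scores)

y_gt   = [[1, 0, 1], [0, 1, 0]]   # binary indicator rows, one per patient
y_pred = [[1, 0, 0], [0, 1, 1]]
print(jaccard_samples(y_gt, y_pred))  # 0.5: (1/2 + 1/2) / 2
```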

Medical Code Map

  • <pyhealth.medcode> provides two core functionalities: (i) looking up information for a given medical code (e.g., name, category, sub-concepts); (ii) mapping codes across coding systems (e.g., ICD9CM to CCSCM). This module can be easily applied to your research.

  • For code mapping between two coding systems

from pyhealth.medcode import CrossMap
codemap = CrossMap("ICD9CM", "CCSCM")
codemap.map("82101") # use it like a dict

codemap = CrossMap("NDC", "ATC", level=3)
codemap.map("00527051210")
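
Conceptually, a cross-system map behaves like a lookup table, with the ATC level selecting how many leading characters of the target code to keep. The sketch below is a toy illustration: the mapping table and the ATC target code are hypothetical, not real CrossMap data.

```python
# ATC code length at each hierarchy level (WHO ATC structure):
# level 1 = 1 char, level 2 = 3, level 3 = 4, level 4 = 5, level 5 = 7.
ATC_LEVEL_LEN = {1: 1, 2: 3, 3: 4, 4: 5, 5: 7}

# Hypothetical single-entry NDC -> ATC level-5 table (illustration only).
ndc_to_atc5 = {"00527051210": "A11HA02"}

def map_ndc_to_atc(ndc, level=5):
    """Look up the full ATC-5 code, then truncate to the requested level."""
    full = ndc_to_atc5[ndc]
    return full[:ATC_LEVEL_LEN[level]]

print(map_ndc_to_atc("00527051210", level=3))  # 'A11H'
```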
  • For code ontology lookup within one system

from pyhealth.medcode import InnerMap
ICD9CM = InnerMap("ICD9CM")
ICD9CM.lookup("428.0") # get detailed info
ICD9CM.get_ancestors("428.0") # get parents

Medical Code Tokenizer

  • <pyhealth.tokenizer> is used for transformations between string-based tokens and integer-based indices, based on the overall token space. We provide flexible functions to tokenize 1D, 2D and 3D lists. This module can be used in many other scenarios.

from pyhealth.tokenizer import Tokenizer

# Example: we use a list of ATC-3 codes as the token space
token_space = ['A01A', 'A02A', 'A02B', 'A02X', 'A03A', 'A03B', 'A03C', 'A03D', \
        'A03E', 'A03F', 'A04A', 'A05A', 'A05B', 'A05C', 'A06A', 'A07A', 'A07B', 'A07C', \
        'A12B', 'A12C', 'A13A', 'A14A', 'A14B', 'A16A']
tokenizer = Tokenizer(tokens=token_space, special_tokens=["<pad>", "<unk>"])

# 2d encode
tokens = [['A03C', 'A03D', 'A03E', 'A03F'], ['A04A', 'B035', 'C129']]
indices = tokenizer.batch_encode_2d(tokens) # [[8, 9, 10, 11], [12, 1, 1, 0]]

# 2d decode
indices = [[8, 9, 10, 11], [12, 1, 1, 0]]
tokens = tokenizer.batch_decode_2d(indices) # [['A03C', 'A03D', 'A03E', 'A03F'], ['A04A', '<unk>', '<unk>']]
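
The core idea behind such a tokenizer can be sketched in a few lines. This is a hypothetical toy class, not PyHealth's actual Tokenizer: special tokens occupy the first indices, unknown tokens fall back to <unk>, and shorter rows are padded with <pad> to the longest row in the batch.

```python
# Toy sketch of token <-> index conversion with <pad>/<unk> handling.
class ToyTokenizer:
    def __init__(self, tokens, special_tokens=("<pad>", "<unk>")):
        # Special tokens come first, so <pad> -> 0 and <unk> -> 1.
        self.vocab = {tok: i for i, tok in enumerate(list(special_tokens) + list(tokens))}

    def encode_2d(self, rows):
        width = max(len(r) for r in rows)
        pad, unk = self.vocab["<pad>"], self.vocab["<unk>"]
        # Unknown tokens map to <unk>; short rows are right-padded with <pad>.
        return [[self.vocab.get(t, unk) for t in r] + [pad] * (width - len(r))
                for r in rows]

tok = ToyTokenizer(["A03C", "A03D", "A04A"])
print(tok.encode_2d([["A03C", "A03D"], ["A04A", "B035", "X999"]]))
# [[2, 3, 0], [4, 1, 1]]
```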

Users can customize their healthcare AI pipeline as simply as calling one module:

  • process your OMOP data via pyhealth.datasets

  • process open EHR data (e.g., MIMIC, eICU) via pyhealth.datasets

  • define your own task on existing databases via pyhealth.tasks

  • use existing healthcare ML models (e.g., RETAIN) or build on them via pyhealth.models.

  • map codes between coding systems for conditions and medications via pyhealth.medcode.


Datasets

We provide the following datasets for general purpose healthcare AI research:

Dataset   | Module                              | Year | Information
----------|-------------------------------------|------|-------------------------------------
MIMIC-III | pyhealth.datasets.MIMIC3BaseDataset | 2016 | MIMIC-III Clinical Database
MIMIC-IV  | pyhealth.datasets.MIMIC4BaseDataset | 2020 | MIMIC-IV Clinical Database
eICU      | pyhealth.datasets.eICUBaseDataset   | 2018 | eICU Collaborative Research Database
OMOP      | pyhealth.datasets.OMOPBaseDataset   | -    | OMOP-CDM schema based dataset

Machine/Deep Learning Models

Model Name                           | Type          | Module                      | Year | Reference
-------------------------------------|---------------|-----------------------------|------|----------------------------------------------------------
Logistic Regression (LR)             | classical ML  | pyhealth.models.MLModel     | -    | sklearn.linear_model.LogisticRegression
Random Forest (RF)                   | classical ML  | pyhealth.models.MLModel     | -    | sklearn.ensemble.RandomForestClassifier
Neural Networks (NN)                 | classical ML  | pyhealth.models.MLModel     | -    | sklearn.neural_network.MLPClassifier
Convolutional Neural Network (CNN)   | deep learning | pyhealth.models.CNN         | 1989 | Handwritten Digit Recognition with a Back-Propagation Network
Recurrent Neural Nets (RNN)          | deep learning | pyhealth.models.RNN         | 2011 | Recurrent neural network based language model
Transformer                          | deep learning | pyhealth.models.Transformer | 2017 | Attention Is All You Need
RETAIN                               | deep learning | pyhealth.models.RETAIN      | 2016 | RETAIN: An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism
GAMENet                              | deep learning | pyhealth.models.GAMENet     | 2019 | GAMENet: Graph Augmented Memory Networks for Recommending Medication Combination
MICRON                               | deep learning | pyhealth.models.MICRON      | 2021 | Change Matters: Medication Change Prediction with Recurrent Residual Networks
SafeDrug                             | deep learning | pyhealth.models.SafeDrug    | 2021 | SafeDrug: Dual Molecular Graph Encoders for Recommending Effective and Safe Drug Combinations

Benchmark on Healthcare Tasks

  • A temporary benchmark document on healthcare tasks is available. Results will be added to this section.
