EHR generation with prompt learning by language models.
PromptEHR
Wang, Zifeng and Sun, Jimeng. (2022). PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning. EMNLP'22.
Usage
Load the pretrained PromptEHR model (trained on MIMIC-III sequential EHRs) in three lines:
from promptehr import PromptEHR
model = PromptEHR()
model.from_pretrained()
A Jupyter notebook example is available at https://github.com/RyanWangZf/PromptEHR/blob/main/example/demo_promptehr.ipynb.
How to install
Install the PyTorch version that matches your system by following https://pytorch.org/get-started/locally/.
Then install PromptEHR with
pip install git+https://github.com/RyanWangZf/PromptEHR.git
or
pip install promptehr
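To confirm that the installation succeeded, the installed version can be queried from Python. The following is a minimal sketch using only the standard library. `installed_version` is a hypothetical helper, not part of PromptEHR, and the distribution name `promptehr` is taken from the pip command above.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_version("promptehr")
    print(version if version is not None else "promptehr is not installed")
```

If the helper returns `None`, the package is not visible to the current interpreter, which often means pip installed it into a different environment.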
Load demo synthetic EHRs (generated by PromptEHR)
from promptehr import load_synthetic_data
data = load_synthetic_data()
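It can help to inspect the loaded object before building model inputs. The sketch below assumes `load_synthetic_data()` returns a dict-like object, which is suggested by the way the generation example indexes it with the keys `'visit'`, `'y'`, `'feature'`, and `'voc'`. `summarize` is a hypothetical helper, and the toy dict is a stand-in with made-up values, not real PromptEHR output.

```python
from typing import Any, Dict, Mapping


def summarize(data: Mapping[str, Any]) -> Dict[str, str]:
    """Describe each entry of a dict-like dataset by its type and, if sized, its length."""
    summary = {}
    for key, value in data.items():
        description = type(value).__name__
        if hasattr(value, "__len__"):
            description += f" (len={len(value)})"
        summary[key] = description
    return summary


# Toy stand-in with made-up values; in practice pass the result of load_synthetic_data().
toy = {
    "visit": [[["a"], ["b"]], [["c"]]],
    "y": [0, 1],
    "feature": [[0.5], [1.5]],
    "voc": {},
}
for key, description in summarize(toy).items():
    print(f"{key}: {description}")
```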
Use PromptEHR for generation
from promptehr import SequencePatient
from promptehr import load_synthetic_data
from promptehr import PromptEHR
# init model
model = PromptEHR()
model.from_pretrained()
# load input data
demo = load_synthetic_data(n_sample=1000)  # 10,000 samples are available in total
# build the standard input data for training or testing PromptEHR models
seqdata = SequencePatient(
    data={'v': demo['visit'], 'y': demo['y'], 'x': demo['feature']},
    metadata={
        'visit': {'mode': 'dense'},
        'label': {'mode': 'tensor'},
        'voc': demo['voc'],
        'max_visit': 20,
    },
)
# optionally, fit the model on this data:
# model.fit(seqdata)
# start generation
# n: the target total number of samples to generate
# n_per_sample: how many synthetic samples to generate from each input sample
# the output has the same format as `SequencePatient`
fake_data = model.predict(seqdata, n=1000, n_per_sample=10)
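A quick sanity check on generated records is to compare the number of visits per patient between the input and the synthetic data. The sketch below is a hypothetical helper, not part of PromptEHR. It assumes the visit records are a nested list whose outer level indexes patients and whose second level indexes visits, and that such a list can be extracted from the generated output. Both assumptions should be verified against the actual objects.

```python
from typing import Dict, Sequence


def visit_count_stats(visits: Sequence[Sequence[object]]) -> Dict[str, float]:
    """Summarize the number of visits per patient for a patient -> visit nested list."""
    counts = [len(patient) for patient in visits]
    if not counts:
        return {"patients": 0, "min": 0, "max": 0, "mean": 0.0}
    return {
        "patients": len(counts),
        "min": min(counts),
        "max": max(counts),
        "mean": sum(counts) / len(counts),
    }


# Toy stand-in: two patients with three visits and one visit (made-up codes).
toy_visits = [[["d1"], ["d2"], ["d3"]], [["d4"]]]
print(visit_count_stats(toy_visits))
```

Running the same helper on the input visits and on the generated visits gives a rough first comparison of sequence lengths. It does not say anything about the clinical plausibility of the generated codes.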