EHR generation with prompt learning by language models.
PromptEHR
Wang, Zifeng and Sun, Jimeng. (2022). PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning. EMNLP'22.
Usage
Get the pretrained PromptEHR model (trained on MIMIC-III sequential EHRs) in three lines:
from promptehr import PromptEHR
model = PromptEHR()
model.from_pretrained()
A Jupyter notebook example is available at https://github.com/RyanWangZf/PromptEHR/blob/main/example/demo_promptehr.ipynb.
How to install
Install the PyTorch version that matches your platform by following https://pytorch.org/get-started/locally/.
Then install PromptEHR with
pip install git+https://github.com/RyanWangZf/PromptEHR.git
or
pip install promptehr
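To confirm that the install succeeded before downloading any model weights, a quick importability check can help. The sketch below uses only the standard library; the helper name `is_installed` is hypothetical and is not part of PromptEHR.

```python
import importlib.util


def is_installed(package_name: str) -> bool:
    """Return True if the import system can locate a top-level package."""
    return importlib.util.find_spec(package_name) is not None


if __name__ == "__main__":
    # PromptEHR depends on PyTorch, so check both packages.
    for name in ("torch", "promptehr"):
        status = "found" if is_installed(name) else "missing"
        print(f"{name}: {status}")
```

If either package is reported as missing, repeat the corresponding install step above in the same Python environment.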
Load demo synthetic EHRs (generated by PromptEHR)
from promptehr import load_synthetic_data
data = load_synthetic_data()
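It can be useful to look at what was loaded before building model inputs. The usage example in this README reads the keys `'visit'`, `'y'`, `'feature'`, and `'voc'` from the returned object; the sketch below summarizes any such dict without assuming more about its schema. The `describe` helper and the toy values are hypothetical placeholders, not real PromptEHR output.

```python
def describe(data: dict) -> dict:
    """Map each key to (type name, length or None) without assuming a schema."""
    summary = {}
    for key, value in data.items():
        length = len(value) if hasattr(value, "__len__") else None
        summary[key] = (type(value).__name__, length)
    return summary


# Toy stand-in with the same top-level keys; the values are made up.
toy = {
    "visit": [[["a"]], [["b"]]],
    "y": [0, 1],
    "feature": [[1.0], [2.0]],
    "voc": {},
}
for key, info in describe(toy).items():
    print(key, info)
```

With the package installed, `describe(load_synthetic_data())` should give a similar per-key overview, assuming the returned object behaves like a dict.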
Use PromptEHR for generation
from promptehr import SequencePatient
from promptehr import load_synthetic_data
from promptehr import PromptEHR
# init model
model = PromptEHR()
model.from_pretrained()
# load input data
demo = load_synthetic_data(n_sample=1000)  # 10,000 samples are available in total
# build the standard input data for training or testing PromptEHR models
seqdata = SequencePatient(
    data={'v': demo['visit'], 'y': demo['y'], 'x': demo['feature']},
    metadata={
        'visit': {'mode': 'dense'},
        'label': {'mode': 'tensor'},
        'voc': demo['voc'],
        'max_visit': 20,
    },
)
# optionally, fit the model on this data first:
# model.fit(seqdata)
# generate synthetic records
# n: the target total number of samples to generate
# n_per_sample: how many fake samples to generate from each input sample
# the output has the same format as `SequencePatient`
fake_data = model.predict(seqdata, n=1000, n_per_sample=10)
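The two generation parameters are related by simple arithmetic: if each input sample yields `n_per_sample` fake samples, reaching a total of `n` requires at least `n / n_per_sample` input samples, rounded up. The sketch below shows only that arithmetic; `source_records_needed` is a hypothetical helper, and how `predict` actually selects or truncates input samples is not described here.

```python
def source_records_needed(n: int, n_per_sample: int) -> int:
    """Smallest number of input samples that yields at least n outputs."""
    if n <= 0 or n_per_sample <= 0:
        raise ValueError("n and n_per_sample must be positive")
    # Integer ceiling division avoids floating-point rounding.
    return (n + n_per_sample - 1) // n_per_sample


print(source_records_needed(1000, 10))  # 100
```

Under this reading, the call above with `n=1000` and `n_per_sample=10` would need about 100 of the 1,000 loaded records as starting points.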