Code for Context Clues paper
Training Long Context Models on EHR Data
This repo contains code and pretrained models for the Context Clues paper. It is designed to train any HuggingFace model on structured EHR data, and it comes with Hydra configs, Wandb logging, and PyTorch Lightning distributed training support.
It currently supports EHR data defined using the MEDS data standard or FEMR package.
📖 Table of Contents
- 🤗 Pretrained HuggingFace Models
- 📀 Installation
- 🚀 Quick Start
- 🏋️‍♀️ Training
- 📊 Evaluation
- 💊 MEDS Demo
- ℹ️ Other
- 🎓 Citation
🤗 Pretrained HuggingFace Models
Please see our HuggingFace Collection to download the following models pretrained from scratch on 2 billion tokens of deidentified structured EHR data:
| Model | Context Lengths |
|---|---|
| gpt | 512, 1024, 2048, 4096 |
| llama | 512, 1024, 2048, 4096 |
| mamba | 1024, 4096, 8192, 16384 |
| hyena | 1024, 4096, 8192, 16384 |
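As a small illustration of how the table above might be used, the sketch below picks the shortest pretrained context length that fits a given number of tokens. The dictionary simply transcribes the table; the function name `smallest_context_length` is hypothetical and is not part of hf_ehr.

```python
from typing import Optional

# Context lengths transcribed from the table above.
PRETRAINED_CONTEXT_LENGTHS = {
    "gpt": [512, 1024, 2048, 4096],
    "llama": [512, 1024, 2048, 4096],
    "mamba": [1024, 4096, 8192, 16384],
    "hyena": [1024, 4096, 8192, 16384],
}


def smallest_context_length(model: str, n_tokens: int) -> Optional[int]:
    """Return the shortest listed context length >= n_tokens, or None if none fits."""
    for length in sorted(PRETRAINED_CONTEXT_LENGTHS[model]):
        if length >= n_tokens:
            return length
    return None


if __name__ == "__main__":
    print(smallest_context_length("gpt", 1500))     # 2048
    print(smallest_context_length("mamba", 1500))   # 4096
    print(smallest_context_length("llama", 10000))  # None
```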
📀 Installation
Direct install:
pip install hf-ehr
For faster Mamba runs, install:
pip install mamba-ssm causal-conv1d
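To check whether these optional packages are visible to the current Python environment, something like the following sketch may help. It assumes the import names `mamba_ssm` and `causal_conv1d` for the two pip packages; the helper names are hypothetical and not part of hf_ehr.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located (without importing it)."""
    return importlib.util.find_spec(module_name) is not None


def mamba_extras_available() -> bool:
    # Assumed import names for the `mamba-ssm` and `causal-conv1d` pip packages.
    return all(is_installed(name) for name in ("mamba_ssm", "causal_conv1d"))


if __name__ == "__main__":
    print("Optional Mamba packages found:", mamba_extras_available())
```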
Development install:
conda create -n hf_env python=3.10 -y
conda activate hf_env
pip install -r requirements.txt --no-cache-dir
pip install -e .
# [Optional] If you haven't already created your **Tokenizers**, run the following. If you're on Carina, then skip this step.
cd hf_ehr/scripts/tokenizers
sbatch clmbr.sh # Takes ~5 seconds
sbatch desc.sh # Takes ~30 min
sbatch cookbook.sh # Takes many hours
🚀 Quick Start
Launch a GPT training run with the ability to configure common hyperparameters:
cd hf_ehr/scripts/carina
python3 main.py --model gpt2 --size base --tokenizer clmbr --context_length 1024 --dataloader approx --dataset v8 --is_run_local --is_force_refresh
Launch a Llama run on a MEDS dataset:
cd hf_ehr/scripts/carina
python3 main.py --model llama --size base --tokenizer clmbr --context_length 1024 --dataloader approx --dataset meds_mimic4_demo --is_run_local --is_force_refresh
To launch 4 GPT-base runs on one SLURM node (in parallel), and 4 Mamba runs on another SLURM node (in parallel):
cd hf_ehr/scripts/carina
# GPT runs
sbatch parallel_gpt.sh
# Mamba runs
sbatch parallel_mamba.sh
🏋️‍♀️ Training
We use Hydra to manage our configurations and PyTorch Lightning for training.
You can either overwrite the config files in configs/ or pass in CLI arguments to override the defaults.
There are 3 ways to launch a training run.
Easy Mode
Launch multiple runs in parallel on the same SLURM node (each job gets 1 GPU) using hf_ehr/scripts/carina/parallel_{model}.sh:
cd hf_ehr/scripts/carina
# Launch 4 gpt runs in parallel on the same node. See the file for the specific model versions run.
sbatch parallel_gpt.sh
# Launch 4 bert runs in parallel on the same node. See the file for the specific model versions run.
sbatch parallel_bert.sh
# Launch 4 hyena runs in parallel on the same node. See the file for the specific model versions run.
sbatch parallel_hyena.sh
# Launch 4 mamba runs in parallel on the same node. See the file for the specific model versions run.
sbatch parallel_mamba.sh
Medium Mode
Launch one run on a SLURM node using hf_ehr/scripts/carina/{model}.sh:
cd hf_ehr/scripts/carina
# Launch GPT-2 base model on v8 dataset with CLMBRTokenizer, ApproxBatchSampler dataloader, and 2048 context length; force training from scratch rather than resuming a prior run (even if one exists)
python3 main.py --model gpt2 --size base --tokenizer clmbr --context_length 2048 --dataloader approx --dataset v8 --is_force_refresh
# Launch Mamba tiny model on v8 dataset with CookbookTokenizer, ApproxBatchSampler dataloader, and 16384 context length; resume prior run if exists
python3 main.py --model mamba --size tiny --tokenizer cookbook --context_length 16384 --dataloader approx --dataset v8
# Launch BERT-base model on v8 dataset with DescTokenizer, ApproxBatchSampler dataloader, and 4096 context length; resume prior run if exists; overwrite the default device assignment to GPU 1; give wandb run a name of `custom`
python3 main.py --model bert --size base --tokenizer desc --context_length 4096 --dataloader approx --dataset v8 --extra "+trainer.devices=[1] logging.wandb.name=custom"
# Run a GPT-2 large model locally on v8 AllTokens dataset with CLMBRTokenizer, ApproxBatchSampler dataloader, and 2048 context length
python3 main.py --model gpt2 --size large --tokenizer clmbr --context_length 2048 --dataloader approx --dataset v8-alltokens --is_run_local
# Launch Mamba tiny model on v8 dataset with CookbookTokenizer, ApproxBatchSampler dataloader, and 16384 context length; resume prior run if exists; run on 8 H100's
python3 main.py --model mamba --size tiny --tokenizer cookbook --context_length 16384 --dataloader approx --dataset v8 --partitions nigam-h100 --extra "trainer=multi_gpu trainer.devices=[0,1,2,3,4,5,6,7]"
General usage:
python3 main.py --model <model> --size <size> --tokenizer <tokenizer> --context_length <context_length> --dataloader <dataloader> --dataset <dataset> [--extra <extra>] [--partitions <partitions>] [--is_force_refresh] [--is_skip_base] [--is_run_local]
where...
- <model>: str -- Architecture to use. Choices are gpt, bert, hyena, mamba
- <size>: str -- Model size to use. Choices are tiny, small, base, medium, large, huge
- <tokenizer>: str -- Tokenizer to use. Choices are clmbr, desc, cookbook
- <context_length>: int -- Context length to use
- <dataloader>: str -- Dataloader to use. Choices are approx, exact
- <dataset>: str -- Dataset to use. Choices are v8, v8-alltokens, v9, v9-alltokens
- [--extra <extra>]: Optional[str] -- An optional string that will get appended to the end of the python ../run.py command verbatim
- [--partitions <partitions>]: Optional[str] -- An optional string that specifies the partitions to use. Defaults to nigam-v100,gpu for gpt2 and BERT, and nigam-h100,nigam-a100 for HYENA and MAMBA
- [--is_force_refresh]: Optional -- An optional flag that triggers a force refresh of the run (i.e., delete the existing run and start from scratch)
- [--is_skip_base]: Optional -- An optional flag that skips running source base.sh. Useful when running parallel.sh and we don't want to reinit the conda environment multiple times
- [--is_run_local]: Optional -- An optional flag that runs the script locally as python run.py instead of as a SLURM sbatch command
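When launching many runs from a script, it can be convenient to assemble the main.py argument list programmatically. The sketch below mirrors only the flags listed above; the helper `build_main_command` is hypothetical (it is not part of hf_ehr) and performs no validation of the choices.

```python
import shlex
from typing import List, Optional


def build_main_command(
    model: str,
    size: str,
    tokenizer: str,
    context_length: int,
    dataloader: str,
    dataset: str,
    extra: Optional[str] = None,
    partitions: Optional[str] = None,
    is_force_refresh: bool = False,
    is_skip_base: bool = False,
    is_run_local: bool = False,
) -> List[str]:
    """Assemble the argument list for main.py from the options listed above."""
    cmd = [
        "python3", "main.py",
        "--model", model,
        "--size", size,
        "--tokenizer", tokenizer,
        "--context_length", str(context_length),
        "--dataloader", dataloader,
        "--dataset", dataset,
    ]
    # Optional string arguments are only added when provided.
    if extra is not None:
        cmd += ["--extra", extra]
    if partitions is not None:
        cmd += ["--partitions", partitions]
    # Boolean flags are appended without a value.
    for flag, enabled in (
        ("--is_force_refresh", is_force_refresh),
        ("--is_skip_base", is_skip_base),
        ("--is_run_local", is_run_local),
    ):
        if enabled:
            cmd.append(flag)
    return cmd


if __name__ == "__main__":
    cmd = build_main_command(
        "mamba", "tiny", "cookbook", 16384, "approx", "v8",
        partitions="nigam-h100", extra="trainer=multi_gpu",
    )
    # shlex.join quotes arguments so the printed line can be pasted into a shell.
    print(shlex.join(cmd))
```

To actually launch a run, the list could be passed to `subprocess.run(cmd, check=True)` from within hf_ehr/scripts/carina.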
Advanced Mode
Directly call run.py, which allows maximum flexibility for configs.
See the Config README for details on all config settings.
cd hf_ehr/scripts/carina
# Launch gpt with: size=base, dataset=v8, context_length=2048, tokenizer=CLMBRTokenizer, sampler=ApproxBatchSampler, max_tokens_per_batch=16384, use_cuda_devices=2,3, wandb_logging_name=gpt2-custom-run, force_restart_existing_run=True, save_to_path=/share/pi/nigam/mwornow/hf_ehr/cache/runs/bert-test/
python3 ../run.py \
+data=v8 \
+trainer=single_gpu \
+model=gpt2-base \
+tokenizer=clmbr \
data.dataloader.mode=approx \
data.dataloader.approx_batch_sampler.max_tokens=16384 \
data.dataloader.max_length=2048 \
model.config_kwargs.n_positions=2048 \
trainer.devices=[2,3] \
logging.wandb.name=gpt2-custom-run \
main.is_force_restart=True \
main.path_to_output_dir=/share/pi/nigam/mwornow/hf_ehr/cache/runs/bert-test/
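In the example above, the context length appears in two overrides (data.dataloader.max_length and model.config_kwargs.n_positions), so it is easy to change one and forget the other. The following sketch keeps them in sync when generating the override list. The override keys are copied from the example; the helper names are hypothetical, and it is an assumption that both keys should always carry the same value for GPT-2 style models.

```python
from typing import Dict, List


def to_hydra_overrides(settings: Dict[str, object]) -> List[str]:
    """Render a dict as `key=value` override strings, preserving insertion order."""
    return [f"{key}={value}" for key, value in settings.items()]


def gpt2_context_overrides(context_length: int) -> Dict[str, object]:
    # Mirrors the example above, where both keys are set to the same value (2048).
    return {
        "data.dataloader.max_length": context_length,
        "model.config_kwargs.n_positions": context_length,
    }


if __name__ == "__main__":
    settings = {
        "+data": "v8",
        "+trainer": "single_gpu",
        "+model": "gpt2-base",
        "+tokenizer": "clmbr",
        "data.dataloader.mode": "approx",
        **gpt2_context_overrides(2048),
        "main.is_force_restart": True,
    }
    # Prints a single-line command equivalent to a subset of the example above.
    print(" ".join(["python3", "../run.py", *to_hydra_overrides(settings)]))
```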
How to Configure Runs
See the Config README for details on all config settings (models, training, dataloaders, tokenizers, etc.).
📊 Evaluation
EHRSHOT
How to use this repo with EHRSHOT.
1. Generate Patient Representations
This all occurs within the hf_ehr repo.
- Identify the path (<path_to_ckpt>) to the model checkpoint you want to evaluate.
- Generate patient representations with your model. This will create a folder in /share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/models for this model checkpoint.
cd hf_ehr/scripts/eval/
sbatch ehrshot.sh <path_to_ckpt>
2. Generate EHRSHOT Results
This all occurs within the ehrshot-benchmark repo.
- Generate your model's AUROC/AUPRC results by running 7_eval.sh:
# cd to ehrshot-benchmark/ehrshot/bash_scripts/ directory
bash 7_eval.sh --is_use_slurm
3. Generate EHRSHOT Plots
This all occurs within the ehrshot-benchmark repo.
- Generate plots by running 8_make_results_plots.sh. You might need to modify the --model_heads parameter in the file before running to specify what gets included in your plots.
# cd to ehrshot-benchmark/ehrshot/bash_scripts/ directory
bash 8_make_results_plots.sh
💊 MEDS Demo
We support training and inference on MEDS formatted datasets.
Here is a quick tutorial using the publicly available MIMIC-IV demo dataset (inspired by this tutorial).
- Download the MIMIC-IV demo dataset from PhysioNet.
export PATH_TO_DOWNLOAD=mimic4_demo
export PATH_TO_MEDS=meds_mimic4_demo
export PATH_TO_MEDS_READER=meds_mimic4_demo_reader
wget -q -r -N -c --no-host-directories --cut-dirs=1 -np -P $PATH_TO_DOWNLOAD https://physionet.org/files/mimic-iv-demo/2.2/
- Convert the MIMIC-IV demo dataset to MEDS format.
rm -rf $PATH_TO_MEDS 2>/dev/null
meds_etl_mimic $PATH_TO_DOWNLOAD $PATH_TO_MEDS
- Convert the MEDS dataset into a MEDS Reader Database (to enable faster data ingestion during training).
rm -rf $PATH_TO_MEDS_READER 2>/dev/null
meds_reader_convert $PATH_TO_MEDS $PATH_TO_MEDS_READER --num_threads 4
- Verify everything worked.
meds_reader_verify $PATH_TO_MEDS $PATH_TO_MEDS_READER
- Create train/val/test splits (80/10/10) by running this Python script:
import meds_reader
import polars as pl
import os
database = meds_reader.SubjectDatabase(os.environ["PATH_TO_MEDS_READER"])
subject_ids = list(database)
splits = [
('train' if idx < 80 else 'tuning' if idx < 90 else 'held_out', subject_ids[idx])
for idx in range(len(subject_ids))
]
df = pl.DataFrame(splits, schema=["split", "subject_id"])
df.write_parquet(os.path.join(os.environ["PATH_TO_MEDS_READER"], 'metadata', 'subject_splits.parquet'))
- Create a Hydra config for your dataset.
cp hf_ehr/configs/data/meds_mimic4_demo.yaml hf_ehr/configs/meds/meds_mimic4_demo_custom.yaml
# Double quotes are needed so that the shell expands $PATH_TO_MEDS_READER
sed -i "s|/share/pi/nigam/mwornow/mimic-iv-demo-meds-reader|$PATH_TO_MEDS_READER|g" hf_ehr/configs/meds/meds_mimic4_demo_custom.yaml
- Train a Llama model on the dataset.
cd hf_ehr/scripts/carina
python3 main.py --model llama --size base --tokenizer clmbr --context_length 1024 --dataloader approx --dataset meds_mimic4_demo_custom --is_run_local --is_force_refresh
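A note on the split step in this tutorial: the script uses fixed index cutoffs (80 and 90), which gives an 80/10/10 split only when the database contains exactly 100 subjects. For datasets of other sizes, a fraction-based variant could look like the sketch below. The function `assign_splits` is hypothetical and not part of hf_ehr or meds_reader; it keeps the same split names (train, tuning, held_out) and the same (split, subject_id) tuple layout as the script.

```python
from typing import List, Tuple


def assign_splits(
    subject_ids: List[int],
    train_frac: float = 0.8,
    tuning_frac: float = 0.1,
) -> List[Tuple[str, int]]:
    """Label subjects as train/tuning/held_out by position, using fractions."""
    n = len(subject_ids)
    n_train = round(n * train_frac)
    n_tuning = round(n * tuning_frac)
    splits = []
    for idx, subject_id in enumerate(subject_ids):
        if idx < n_train:
            name = "train"
        elif idx < n_train + n_tuning:
            name = "tuning"
        else:
            # Everything left over goes to the held-out split.
            name = "held_out"
        splits.append((name, subject_id))
    return splits


if __name__ == "__main__":
    demo = assign_splits(list(range(1000, 1010)))
    print([name for name, _ in demo])
```

The returned list can be passed to pl.DataFrame in place of the splits list from the script. Like the original, this assigns splits by position, so shuffle subject_ids first if their order is not already random.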
ℹ️ Other
Llama
The llama model checkpoints we saved only work with transformers 4.44.2.
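A quick way to confirm the installed version before loading a Llama checkpoint is sketched below. The version string comes from the note above; the helper names are hypothetical and not part of hf_ehr.

```python
from importlib import metadata
from typing import Optional

REQUIRED_TRANSFORMERS = "4.44.2"  # version stated in the note above


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a pip package, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_transformers(found: Optional[str]) -> str:
    """Return "ok" or a human-readable description of the mismatch."""
    if found is None:
        return "transformers is not installed"
    if found != REQUIRED_TRANSFORMERS:
        return (
            f"transformers {found} found; "
            f"the Llama checkpoints expect {REQUIRED_TRANSFORMERS}"
        )
    return "ok"


if __name__ == "__main__":
    print(check_transformers(installed_version("transformers")))
```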
Based
To run the Based model, install the following on a node with an A100 or newer GPU:
pip install -v \
--disable-pip-version-check \
--no-cache-dir \
--no-build-isolation \
--config-settings "--build-option=--cpp_ext" \
--config-settings "--build-option=--cuda_ext" \
'git+https://github.com/NVIDIA/apex@b496d85' --no-cache-dir
pip install --no-cache-dir \
torch==2.1.2 \
torchvision==0.16.2 \
torchaudio==2.1.2 \
--index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
# Install FLA triton kernel
pip install -U git+https://github.com/sustcsonglin/flash-linear-attention
pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2' --no-build-isolation --no-cache-dir
pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2#subdirectory=csrc/fused_dense_lib' --no-build-isolation --no-cache-dir
pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2#subdirectory=csrc/layer_norm' --no-build-isolation --no-cache-dir
git clone git@github.com:HazyResearch/based.git
cd based
pip install -e . --no-cache-dir
🤖 Creating a Model
Let's say we want to create a new model called {model} of size {size}.
- Create the Hydra config YAML for your model architecture in hf_ehr/configs/architecture/{model}.yaml. Copy the contents of hf_ehr/configs/architecture/bert.yaml and modify as needed.
- Create the Hydra config YAML for your model instantiation in hf_ehr/configs/models/{model}-{size}.yaml. Copy the contents of hf_ehr/configs/models/bert-base.yaml and modify as needed.
- Create the model itself by creating a new file hf_ehr/models/{model}.py. Copy the contents of models/bert.py and modify as needed.
- Add your model to hf_ehr/scripts/run.py above the line raise ValueError(f"Model {config.model.name} not supported.")
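The last step amounts to adding a dispatch branch for the new model name before the fallback error. The sketch below illustrates that pattern with a registry; it is an assumption about the general shape only, since the real run.py may use an if/elif chain and builds actual model objects. All names here (`MODEL_BUILDERS`, `register_model`, `build_model`, `my_model`) are hypothetical.

```python
from typing import Callable, Dict

# Hypothetical registry mapping a model name to a function that builds it.
MODEL_BUILDERS: Dict[str, Callable[[dict], object]] = {}


def register_model(name: str):
    """Decorator that records a builder function under the given model name."""
    def decorator(builder: Callable[[dict], object]):
        MODEL_BUILDERS[name] = builder
        return builder
    return decorator


@register_model("bert")
def build_bert(config: dict) -> object:
    return ("bert", config)  # placeholder standing in for a real model object


@register_model("my_model")
def build_my_model(config: dict) -> object:
    return ("my_model", config)  # placeholder for the newly added model


def build_model(name: str, config: dict) -> object:
    # Unknown names fall through to the same error message quoted above.
    if name not in MODEL_BUILDERS:
        raise ValueError(f"Model {name} not supported.")
    return MODEL_BUILDERS[name](config)


if __name__ == "__main__":
    print(build_model("my_model", {"size": "base"}))
```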
✂️ Creating a Tokenizer
See the Tokenizer README for details on creating tokenizers and how they are stored on the file system.
🤗 Uploading a Model to Hugging Face
See the Hugging Face README for details on uploading models to Hugging Face.
🎓 Citation
If you found this work useful, please consider citing it:
@article{wornow2024contextclues,
title={Context Clues: Evaluating Long Context Models for Clinical Prediction Tasks on EHRs},
author={Michael Wornow and Suhana Bedi and Miguel Angel Fuentes Hernandez and Ethan Steinberg and Jason Alan Fries and Christopher Ré and Sanmi Koyejo and Nigam H. Shah},
year={2024},
eprint={2412.16178},
url={https://arxiv.org/abs/2412.16178},
}