CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
CEHRGPT
CEHRGPT is a multi-task foundation model for structured electronic health records (EHR) data that supports three capabilities: feature representation, zero-shot prediction, and synthetic data generation.
🎯 Key Capabilities
Feature Representation
Extract meaningful patient embeddings from sequences of medical events using linear probing techniques for downstream tasks such as disease prediction, patient clustering, and risk stratification.
Zero-Shot Prediction
Generate outcome predictions directly from prompts without requiring task-specific training, enabling rapid evaluation in low-label clinical settings.
Synthetic Data Generation
Generate comprehensive patient profiles including demographics, medical history, treatment courses, and outcomes while implementing advanced privacy-preserving techniques to ensure generated data contains no identifiable information. The platform is fully compatible with the OMOP Common Data Model for seamless integration with existing healthcare systems.
🚀 Installation
Clone the repository and install dependencies:
git clone https://github.com/knatarajan-lab/cehrgpt.git
cd cehrgpt
pip install .
📋 Prerequisites
Before getting started, set up the required environment variables:
export CEHRGPT_HOME=$(git rev-parse --show-toplevel)
export OMOP_DIR="" # Path to your OMOP data
export CEHR_GPT_DATA_DIR="" # Path for processed data storage
export CEHR_GPT_MODEL_DIR="" # Path for model storage
Create the dataset cache directory:
mkdir -p "$CEHR_GPT_MODEL_DIR/dataset_prepared"
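All four variables must be set before the steps below. A small sketch for validating them up front (the helper `missing_env_vars` is illustrative, not part of the package):

```python
import os

def missing_env_vars(required, env=os.environ):
    """Return the names of required variables that are unset or empty."""
    return [v for v in required if not env.get(v)]

required = ["CEHRGPT_HOME", "OMOP_DIR", "CEHR_GPT_DATA_DIR", "CEHR_GPT_MODEL_DIR"]
# example against a hypothetical partially-configured environment
missing = missing_env_vars(required, env={"CEHRGPT_HOME": "/repo", "OMOP_DIR": "/omop"})
# → ['CEHR_GPT_DATA_DIR', 'CEHR_GPT_MODEL_DIR']
```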
🏗️ Model Training
Step 1: Generate Pre-training Data from OMOP
Generate the training data following the Data Generation Instruction.
Step 2: Pre-train CEHR-GPT
Train the foundation model:
python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
--model_name_or_path $CEHR_GPT_MODEL_DIR \
--tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
--output_dir $CEHR_GPT_MODEL_DIR \
--data_folder "$CEHR_GPT_DATA_DIR/patient_sequence/train" \
--dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
--do_train true --seed 42 \
--dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
--hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 4096 \
--evaluation_strategy epoch --save_strategy epoch \
--sample_packing --max_tokens_per_batch 16384 \
--warmup_ratio 0.01 --weight_decay 0.01 \
--num_train_epochs 50 --learning_rate 0.0002 \
--use_early_stopping \
--load_best_model_at_end true \
--early_stopping_threshold 0.001
Tip: Increase `max_position_embeddings` for longer context windows based on your use case.
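The `--sample_packing` flag packs multiple patient sequences into each batch up to the `--max_tokens_per_batch` budget, reducing padding waste. A minimal greedy sketch of the idea (not the project's actual implementation):

```python
def pack_sequences(lengths, max_tokens_per_batch):
    """Greedily group sequence indices into batches whose summed
    token count stays within max_tokens_per_batch."""
    batches, current, used = [], [], 0
    for i, n in enumerate(lengths):
        if used + n > max_tokens_per_batch and current:
            batches.append(current)  # close the full batch
            current, used = [], 0
        current.append(i)
        used += n
    if current:
        batches.append(current)
    return batches

# five patient sequences packed under a 1600-token budget
batches = pack_sequences([900, 700, 500, 400, 300], max_tokens_per_batch=1600)
# → [[0, 1], [2, 3, 4]]
```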
For DDP (DistributedDataParallel) training, launch the script with `torchrun`:
torchrun --nproc_per_node=2 src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py \
--model_name_or_path $CEHR_GPT_MODEL_DIR \
--tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
--output_dir $CEHR_GPT_MODEL_DIR \
--data_folder "$CEHR_GPT_DATA_DIR/patient_sequence/train" \
--dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
--do_train true --seed 42 \
--dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
--hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 4096 \
--evaluation_strategy epoch --save_strategy epoch \
--sample_packing --max_tokens_per_batch 16384 \
--warmup_ratio 0.01 --weight_decay 0.01 \
--num_train_epochs 50 --learning_rate 0.0002 \
--use_early_stopping \
--load_best_model_at_end true \
--early_stopping_threshold 0.001
To train large models using sharding with DeepSpeed:
pip install deepspeed
deepspeed --num_gpus=2 src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py \
--model_name_or_path $CEHR_GPT_MODEL_DIR \
--tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
--output_dir $CEHR_GPT_MODEL_DIR \
--data_folder "$CEHR_GPT_DATA_DIR/patient_sequence/train" \
--dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
--do_train true --seed 42 \
--dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
--hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 4096 \
--evaluation_strategy epoch --save_strategy epoch \
--sample_packing --max_tokens_per_batch 16384 \
--warmup_ratio 0.01 --weight_decay 0.01 \
--num_train_epochs 50 --learning_rate 0.0002 \
--use_early_stopping \
--load_best_model_at_end true \
--early_stopping_threshold 0.001 \
--deepspeed sample_configs/zero_stage3_config.json
🎯 Feature Representation
CEHR-GPT enables extraction of meaningful patient embeddings from medical event sequences using linear probing techniques for downstream prediction tasks. The feature representation pipeline includes label generation, patient sequence extraction, and linear regression model training on the extracted representations.
For detailed instructions including cohort creation, patient feature extraction, and linear probing evaluation, please follow the Feature Representation Guide.
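Linear probing trains a lightweight classifier on frozen patient embeddings. A self-contained sketch with synthetic data standing in for CEHR-GPT embeddings (the probe below is a plain logistic regression fit by gradient descent, for illustration only):

```python
import numpy as np

def train_linear_probe(X, y, lr=0.1, epochs=500):
    """Fit a logistic-regression probe on frozen embeddings X
    (n_patients x dim) against binary labels y."""
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))  # predicted risk
        w -= lr * (X.T @ (p - y) / len(y))      # log-loss gradient step
        b -= lr * float(np.mean(p - y))
    return w, b

# toy stand-in: 200 "patients", signal carried by the first dimension
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 8))
y = (X[:, 0] > 0).astype(float)
w, b = train_linear_probe(X, y)
accuracy = (((X @ w + b) > 0) == y.astype(bool)).mean()  # well above chance
```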
🔮 Zero-Shot Prediction
CEHR-GPT can generate outcome predictions directly from clinical prompts without requiring task-specific training, making it ideal for rapid evaluation in low-label clinical settings. The zero-shot prediction capability performs time-to-event analysis by processing patient sequences and generating risk predictions based on learned medical patterns.
For complete setup instructions including label generation, sequence preparation, and prediction execution, please follow the Zero-Shot Prediction Guide.
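One common way to turn a generative model into a zero-shot risk estimator is to sample several future trajectories per patient and report the fraction that contain the outcome of interest within the horizon. A hedged sketch of that pattern (the function and concept tokens are illustrative, not the package API):

```python
def zero_shot_risk(trajectories, outcome_token):
    """Estimate outcome probability as the fraction of sampled
    future trajectories containing the outcome token."""
    hits = sum(outcome_token in traj for traj in trajectories)
    return hits / len(trajectories)

# four hypothetical sampled continuations for one patient
sampled = [
    ["visit_start", "250.00", "visit_end"],
    ["visit_start", "401.1", "visit_end"],
    ["visit_start", "250.00", "401.1", "visit_end"],
    ["visit_start", "visit_end"],
]
risk = zero_shot_risk(sampled, outcome_token="250.00")  # → 0.5
```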
🧬 Synthetic Data Generation
CEHR-GPT generates comprehensive synthetic patient profiles including demographics, medical history, treatment courses, and outcomes while implementing advanced privacy-preserving techniques. The synthetic data maintains statistical fidelity to real patient populations without containing identifiable information, and outputs are fully compatible with the OMOP Common Data Model.
For step-by-step instructions on generating synthetic sequences and converting them to OMOP format, please follow the Synthetic Data Generation Guide.
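Statistical fidelity is often spot-checked by comparing concept prevalence between real and synthetic cohorts. A minimal sketch of one such check (the helper and toy concept codes are illustrative, not part of the evaluation pipeline):

```python
from collections import Counter

def concept_prevalence(sequences):
    """Fraction of patients whose sequence contains each concept."""
    counts = Counter()
    for seq in sequences:
        counts.update(set(seq))  # count each concept once per patient
    return {c: n / len(sequences) for c, n in counts.items()}

real = [["a", "b"], ["a"], ["b", "c"], ["a", "c"]]
synthetic = [["a"], ["a"], ["c"], ["a", "b"]]

real_p = concept_prevalence(real)
synth_p = concept_prevalence(synthetic)
# largest per-concept prevalence gap: a crude fidelity score
gap = max(abs(real_p.get(c, 0) - synth_p.get(c, 0))
          for c in set(real_p) | set(synth_p))
# → 0.25
```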
📊 MEDS Support
CEHR-GPT supports the Medical Event Data Standard (MEDS) format for enhanced interoperability.
Prerequisites
Configure MEDS-specific environment variables:
export CEHR_GPT_MODEL_DIR="" # CEHR-GPT model directory
export MEDS_DIR="" # MEDS data directory
export MEDS_READER_DIR="" # MEDS reader output directory
Step 1: Create MIMIC MEDS Data
Transform MIMIC files to MEDS format following the MEDS_transforms repository instructions.
Step 2: Prepare MEDS Reader
Convert MEDS data for CEHR-GPT compatibility:
meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10
Step 3: Pre-train with MEDS Data
Execute pre-training using MEDS format:
python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
--model_name_or_path $CEHR_GPT_MODEL_DIR \
--tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
--output_dir $CEHR_GPT_MODEL_DIR \
--data_folder $MEDS_READER_DIR \
--dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
--do_train true --seed 42 \
--dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
--hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
--evaluation_strategy epoch --save_strategy epoch \
--sample_packing --max_tokens_per_batch 16384 \
--warmup_steps 500 --weight_decay 0.01 \
--num_train_epochs 50 --learning_rate 0.0002 \
--use_early_stopping --early_stopping_threshold 0.001 \
--is_data_in_meds --inpatient_att_function_type day \
--att_function_type day --include_inpatient_hour_token \
--include_auxiliary_token --include_demographic_prompt \
--meds_to_cehrbert_conversion_type "MedsToBertMimic4"
Step 4: Generate MEDS Trajectories
Environment Setup
Configure trajectory generation environment:
export MEDS_LABEL_COHORT_DIR="" # Cohort labels directory (parquet files)
export MEDS_TRAJECTORY_DIR="" # Trajectory output directory
Generate Synthetic Trajectories
Create patient trajectories with the trained model:
python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
--cohort_folder $MEDS_LABEL_COHORT_DIR \
--data_folder $MEDS_READER_DIR \
--dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
--model_name_or_path $CEHR_GPT_MODEL_DIR \
--tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
--output_dir $MEDS_TRAJECTORY_DIR \
--per_device_eval_batch_size 16 \
--num_of_trajectories_per_sample 2 \
--generation_input_length 4096 \
--generation_max_new_tokens 4096 \
--is_data_in_meds \
--att_function_type day --inpatient_att_function_type day \
--meds_to_cehrbert_conversion_type MedsToBertMimic4 \
--include_auxiliary_token --include_demographic_prompt \
--include_inpatient_hour_token
Important: Ensure `generation_input_length` + `generation_max_new_tokens` ≤ `max_position_embeddings` (8192).
Parameter Reference
- `generation_input_length`: Input context length for generation
- `generation_max_new_tokens`: Maximum new tokens to generate
- `num_of_trajectories_per_sample`: Number of trajectories per patient sample
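The budget constraint above is simple arithmetic worth enforcing before a long generation run. A small guard (the function name is illustrative):

```python
def check_generation_budget(input_length, max_new_tokens, max_position_embeddings):
    """Prompt plus generated continuation must fit within the
    model's positional window."""
    total = input_length + max_new_tokens
    if total > max_position_embeddings:
        raise ValueError(
            f"input ({input_length}) + new tokens ({max_new_tokens}) = {total} "
            f"exceeds max_position_embeddings ({max_position_embeddings})"
        )
    return total

check_generation_budget(4096, 4096, 8192)  # exactly fills the 8192 window
```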
📖 Citation
If you use CEHRGPT in your research, please cite:
@article{cehrgpt2024,
title={CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines},
author={Natarajan, K and others},
journal={arXiv preprint arXiv:2402.04400},
year={2024}
}
📄 License
This project is licensed under the MIT License - see the LICENSE file for details.