CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines

Project description

CEHRGPT

CEHRGPT is a multi-task foundation model for structured electronic health records (EHR) data that supports three capabilities: feature representation, zero-shot prediction, and synthetic data generation.

🎯 Key Capabilities

Feature Representation

Extract meaningful patient embeddings from sequences of medical events using linear probing techniques for downstream tasks such as disease prediction, patient clustering, and risk stratification.

Zero-Shot Prediction

Generate outcome predictions directly from prompts without requiring task-specific training, enabling rapid evaluation in low-label clinical settings.

Synthetic Data Generation

Generate comprehensive patient profiles including demographics, medical history, treatment courses, and outcomes while implementing advanced privacy-preserving techniques to ensure generated data contains no identifiable information. The platform is fully compatible with the OMOP Common Data Model for seamless integration with existing healthcare systems.

🚀 Installation

Clone the repository and install dependencies:

git clone https://github.com/knatarajan-lab/cehrgpt.git
cd cehrgpt
pip install .

📋 Prerequisites

Before getting started, set up the required environment variables:

export CEHRGPT_HOME=$(git rev-parse --show-toplevel)
export OMOP_DIR=""                    # Path to your OMOP data
export CEHR_GPT_DATA_DIR=""          # Path for processed data storage
export CEHR_GPT_MODEL_DIR=""         # Path for model storage

Create the dataset cache directory:

mkdir -p "$CEHR_GPT_MODEL_DIR/dataset_prepared"
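
Because every later command interpolates these variables, it can help to fail fast when one is empty. A minimal bash helper (hypothetical, not part of the cehrgpt CLI) might look like:

```shell
# Hypothetical helper: fail fast if any required environment
# variable is unset or empty (uses bash indirect expansion).
require_vars() {
  local var missing=0
  for var in "$@"; do
    if [ -z "${!var}" ]; then
      echo "ERROR: $var is not set" >&2
      missing=1
    fi
  done
  return "$missing"
}

# Example: run before launching any training job
# require_vars CEHRGPT_HOME OMOP_DIR CEHR_GPT_DATA_DIR CEHR_GPT_MODEL_DIR
```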

🏗️ Model Training

Step 1: Generate Pre-training Data from OMOP

Generate the training data following the Data Generation Instruction.

Step 2: Pre-train CEHR-GPT

Train the foundation model:

python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
  --model_name_or_path $CEHR_GPT_MODEL_DIR \
  --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
  --output_dir $CEHR_GPT_MODEL_DIR \
  --data_folder "$CEHR_GPT_DATA_DIR/patient_sequence/train" \
  --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
  --do_train true --seed 42 \
  --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
  --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 4096 \
  --evaluation_strategy epoch --save_strategy epoch \
  --sample_packing --max_tokens_per_batch 16384 \
  --warmup_ratio 0.01 --weight_decay 0.01 \
  --num_train_epochs 50 --learning_rate 0.0002 \
  --use_early_stopping \
  --load_best_model_at_end true \
  --early_stopping_threshold 0.001

Tip: Increase max_position_embeddings for longer context windows based on your use case.

For distributed data parallel (DDP) training, launch the script with torchrun:

torchrun --nproc_per_node=2 src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py \
  --model_name_or_path $CEHR_GPT_MODEL_DIR \
  --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
  --output_dir $CEHR_GPT_MODEL_DIR \
  --data_folder "$CEHR_GPT_DATA_DIR/patient_sequence/train" \
  --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
  --do_train true --seed 42 \
  --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
  --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 4096 \
  --evaluation_strategy epoch --save_strategy epoch \
  --sample_packing --max_tokens_per_batch 16384 \
  --warmup_ratio 0.01 --weight_decay 0.01 \
  --num_train_epochs 50 --learning_rate 0.0002 \
  --use_early_stopping \
  --load_best_model_at_end true \
  --early_stopping_threshold 0.001

To train large models using sharding with DeepSpeed:

pip install deepspeed
deepspeed --num_gpus=2 src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py \
  --model_name_or_path $CEHR_GPT_MODEL_DIR \
  --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
  --output_dir $CEHR_GPT_MODEL_DIR \
  --data_folder "$CEHR_GPT_DATA_DIR/patient_sequence/train" \
  --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
  --do_train true --seed 42 \
  --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
  --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 4096 \
  --evaluation_strategy epoch --save_strategy epoch \
  --sample_packing --max_tokens_per_batch 16384 \
  --warmup_ratio 0.01 --weight_decay 0.01 \
  --num_train_epochs 50 --learning_rate 0.0002 \
  --use_early_stopping \
  --load_best_model_at_end true \
  --early_stopping_threshold 0.001 \
  --deepspeed sample_configs/zero_stage3_config.json

🎯 Feature Representation

CEHR-GPT enables extraction of meaningful patient embeddings from medical event sequences using linear probing techniques for downstream prediction tasks. The feature representation pipeline includes label generation, patient sequence extraction, and linear regression model training on the extracted representations.

For detailed instructions including cohort creation, patient feature extraction, and linear probing evaluation, please follow the Feature Representation Guide.
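
As a toy illustration of the linear-probing idea (not the project's actual pipeline), a probe is simply a linear classifier fit on frozen embeddings; the synthetic vectors below stand in for real CEHR-GPT patient representations:

```python
# Linear-probing sketch: fit a logistic regression on frozen patient
# embeddings. The embeddings are synthetic stand-ins for what a
# trained CEHR-GPT encoder would produce.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
n, dim = 400, 768                       # 768 matches --hidden_size above
labels = rng.integers(0, 2, size=n)     # e.g. 1 = disease onset
# Shift class-1 means so the classes are (roughly) linearly separable.
emb = rng.normal(size=(n, dim)) + labels[:, None] * 0.5

X_tr, X_te, y_tr, y_te = train_test_split(emb, labels, random_state=0)
probe = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
print(f"probe accuracy: {probe.score(X_te, y_te):.2f}")
```

The point of linear probing is that only the probe's weights are trained; if a linear model on frozen embeddings predicts the outcome well, the representation itself carries the signal.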

🔮 Zero-Shot Prediction

CEHR-GPT can generate outcome predictions directly from clinical prompts without requiring task-specific training, making it ideal for rapid evaluation in low-label clinical settings. The zero-shot prediction capability performs time-to-event analysis by processing patient sequences and generating risk predictions based on learned medical patterns.

For complete setup instructions including label generation, sequence preparation, and prediction execution, please follow the Zero-Shot Prediction Guide.
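
One common way to frame such predictions is to sample several future trajectories per patient and score the fraction that contain the outcome within the horizon. The sketch below is conceptual; the token strings are invented for illustration rather than drawn from CEHR-GPT's actual vocabulary:

```python
# Conceptual zero-shot risk estimate: given N sampled continuations
# for one patient, the risk is the fraction of trajectories in which
# the outcome concept appears. Token names here are illustrative only.
def zero_shot_risk(trajectories, outcome_token):
    hits = sum(outcome_token in traj for traj in trajectories)
    return hits / len(trajectories)

sampled = [
    ["VISIT", "ICD10/I10", "D7"],          # no event
    ["VISIT", "ICD10/E11", "OUTCOME"],     # event occurs
    ["VISIT", "OUTCOME", "D30"],           # event occurs
    ["VISIT", "ICD10/I10"],                # no event
]
print(zero_shot_risk(sampled, "OUTCOME"))  # → 0.5
```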

🧬 Synthetic Data Generation

CEHR-GPT generates comprehensive synthetic patient profiles including demographics, medical history, treatment courses, and outcomes while implementing advanced privacy-preserving techniques. The synthetic data maintains statistical fidelity to real patient populations without containing identifiable information, and outputs are fully compatible with the OMOP Common Data Model.

For step-by-step instructions on generating synthetic sequences and converting them to OMOP format, please follow the Synthetic Data Generation Guide.
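
To make the OMOP conversion concrete, the toy converter below maps a generated token sequence to condition_occurrence-style rows. The "D&lt;n&gt;" day-gap convention and the concept IDs are assumptions made for this sketch, not the project's actual converter:

```python
# Illustrative converter: walk a generated token sequence, advancing a
# cursor date on artificial time tokens ("D7" = 7-day gap) and emitting
# an OMOP-style row for every concept token.
from datetime import date, timedelta

def tokens_to_rows(tokens, start=date(2020, 1, 1)):
    rows, current = [], start
    for tok in tokens:
        if tok.startswith("D") and tok[1:].isdigit():  # time-gap token
            current += timedelta(days=int(tok[1:]))
        else:                                          # concept token
            rows.append({"concept_id": tok,
                         "condition_start_date": current})
    return rows

rows = tokens_to_rows(["320128", "D7", "201826"])
```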

📊 MEDS Support

CEHR-GPT supports the Medical Event Data Standard (MEDS) format for enhanced interoperability.

Prerequisites

Configure MEDS-specific environment variables:

export CEHR_GPT_MODEL_DIR=""    # CEHR-GPT model directory
export MEDS_DIR=""              # MEDS data directory
export MEDS_READER_DIR=""       # MEDS reader output directory

Step 1: Create MIMIC MEDS Data

Transform MIMIC files to MEDS format following the MEDS_transforms repository instructions.

Step 2: Prepare MEDS Reader

Convert MEDS data for CEHR-GPT compatibility:

meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10

Step 3: Pre-train with MEDS Data

Execute pre-training using MEDS format:

python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
  --model_name_or_path $CEHR_GPT_MODEL_DIR \
  --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
  --output_dir $CEHR_GPT_MODEL_DIR \
  --data_folder $MEDS_READER_DIR \
  --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
  --do_train true --seed 42 \
  --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
  --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
  --evaluation_strategy epoch --save_strategy epoch \
  --sample_packing --max_tokens_per_batch 16384 \
  --warmup_steps 500 --weight_decay 0.01 \
  --num_train_epochs 50 --learning_rate 0.0002 \
  --use_early_stopping --early_stopping_threshold 0.001 \
  --is_data_in_meds --inpatient_att_function_type day \
  --att_function_type day --include_inpatient_hour_token \
  --include_auxiliary_token --include_demographic_prompt \
  --meds_to_cehrbert_conversion_type "MedsToBertMimic4"

Step 4: Generate MEDS Trajectories

Environment Setup

Configure trajectory generation environment:

export MEDS_LABEL_COHORT_DIR=""     # Cohort labels directory (parquet files)
export MEDS_TRAJECTORY_DIR=""       # Trajectory output directory

Generate Synthetic Trajectories

Create patient trajectories with the trained model:

python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
  --cohort_folder $MEDS_LABEL_COHORT_DIR \
  --data_folder $MEDS_READER_DIR \
  --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
  --model_name_or_path $CEHR_GPT_MODEL_DIR \
  --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
  --output_dir $MEDS_TRAJECTORY_DIR \
  --per_device_eval_batch_size 16 \
  --num_of_trajectories_per_sample 2 \
  --generation_input_length 4096 \
  --generation_max_new_tokens 4096 \
  --is_data_in_meds \
  --att_function_type day --inpatient_att_function_type day \
  --meds_to_cehrbert_conversion_type MedsToBertMimic4 \
  --include_auxiliary_token --include_demographic_prompt \
  --include_inpatient_hour_token

Important: Ensure that generation_input_length + generation_max_new_tokens ≤ max_position_embeddings (8192).

Parameter Reference

  • generation_input_length: Input context length for generation
  • generation_max_new_tokens: Maximum new tokens to generate
  • num_of_trajectories_per_sample: Number of trajectories per patient sample
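
The constraint above can be enforced with a small guard before launching generation (a hypothetical check, assuming the 8192-token window configured by --max_position_embeddings):

```python
# Hypothetical pre-flight check: the input context plus the generation
# budget must fit inside the model's position-embedding window.
MAX_POSITION_EMBEDDINGS = 8192

def check_generation_budget(input_length, max_new_tokens,
                            limit=MAX_POSITION_EMBEDDINGS):
    total = input_length + max_new_tokens
    if total > limit:
        raise ValueError(
            f"input_length + max_new_tokens = {total} exceeds {limit}")
    return total

check_generation_budget(4096, 4096)  # OK: exactly at the limit
```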

📖 Citation

If you use CEHRGPT in your research, please cite:

@article{cehrgpt2024,
  title={CEHRGPT: Synthetic Data Generation for Electronic Health Records},
  author={Natarajan, K and others},
  journal={arXiv preprint arXiv:2402.04400},
  year={2024}
}

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

Download files

Source Distribution

cehrgpt-0.1.6.post4.tar.gz (6.0 MB)

Built Distribution

cehrgpt-0.1.6.post4-py3-none-any.whl (211.4 kB)

File details

Details for the file cehrgpt-0.1.6.post4.tar.gz.

File metadata

  • Download URL: cehrgpt-0.1.6.post4.tar.gz
  • Size: 6.0 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for cehrgpt-0.1.6.post4.tar.gz:

  • SHA256: 51311a348a327b287b20ba5cffbaee3c63d22f03541d500cbb3e1f3eb8a4691a
  • MD5: 1c0aac837ce9c7042de8d966c0412f55
  • BLAKE2b-256: 6fda5df2bced4e0be4c72f23e109c9107d30f7f4787820bec7d1d1cf21fc7e4b

Provenance

The following attestation bundles were made for cehrgpt-0.1.6.post4.tar.gz:

Publisher: build-python.yaml on knatarajan-lab/cehrgpt

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file cehrgpt-0.1.6.post4-py3-none-any.whl.

File metadata

  • Download URL: cehrgpt-0.1.6.post4-py3-none-any.whl
  • Size: 211.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for cehrgpt-0.1.6.post4-py3-none-any.whl:

  • SHA256: d06930b6878be5762ea9a155da87182e40b95437116df7616404981b40259c05
  • MD5: 8e7dc807ba095b2d0b2bb7d6f0f8e716
  • BLAKE2b-256: 150a3da0f866f9a65baf99921837dadb6b581e3c9a767d43f3527651875a385f

Provenance

The following attestation bundles were made for cehrgpt-0.1.6.post4-py3-none-any.whl:

Publisher: build-python.yaml on knatarajan-lab/cehrgpt

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
