A simple auto-regressive, 'everything-is-code' style model for MEDS datasets
MEDS "Everything-is-code" Autoregressive Model
A MEDS, "Everything-is-code" style Autoregressive Generative Model, capable of zero-shot inference.
Installation
pip install MEDS-EIC-AR
Optional Dependencies
WandB
If you want to use WandB for logging, you can install it via:
pip install MEDS-EIC-AR[wandb]
MLFlow
If you want to use MLFlow for logging, you can install it via:
pip install MEDS-EIC-AR[mlflow]
This will also install psutil and pynvml as dependencies, to enable MLFlow tracking of system CPU and GPU
resources, which is enabled by default or can be controlled via the MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING
environment variable. See the
MLFlow documentation for
more details.
Flash Attention
To use flash attention, you must install the flash-attn package separately, after installing this package. This can often be done via:
pip install flash-attn --no-build-isolation
If you encounter errors, see the flash-attn package documentation.
Usage
1. Pre-process your data
You have three directories:
- $RAW_MEDS_DIR -- The raw MEDS data directory that you want to pre-process.
- $INTERMEDIATE_DIR -- An intermediate directory where the partially processed data will be stored prior to tokenization and tensorization.
- $FINAL_DATA_DIR -- The final output directory where the tokenized and tensorized data will be stored. This directory is suitable for use in loading the data with meds-torch-data.
Run:
MEICAR_process_data input_dir="$RAW_MEDS_DIR" \
intermediate_dir="$INTERMEDIATE_DIR" \
output_dir="$FINAL_DATA_DIR"
[!NOTE] If your data is not sharded by split at the outset, you will need to add the
do_reshard=True command line parameter to the MEICAR_process_data command, which ensures the system reshards the data to be sub-sharded by split before beginning pre-processing.
You can also run this in demo mode, which lowers the filtering thresholds significantly so the script does not filter out all data:
MEICAR_process_data ... do_demo=True
You can exert more fine-grained control on the filtering with the following environment variables:
- MIN_SUBJECTS_PER_CODE: How many subjects must a given code be observed within to be included in the final vocabulary? Note that this excludes some sentinel codes which are always retained.
- MIN_EVENTS_PER_SUBJECT: How many events must a subject have to be included in the final dataset?
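To make the two thresholds concrete, here is a minimal, self-contained sketch of this kind of filtering on toy data. It is not the package's implementation: the function name filter_events, the (subject_id, code) row layout, the choice to count each row as one event, the order of the two filters, and the omission of sentinel codes are all assumptions made for illustration.

```python
from collections import defaultdict


def filter_events(events, min_subjects_per_code, min_events_per_subject):
    """Toy filter over a list of (subject_id, code) rows.

    Assumption: codes are filtered first, then subjects. The real pipeline
    may order these steps differently and always keeps some sentinel codes.
    """
    # Count how many distinct subjects each code is observed in.
    subjects_per_code = defaultdict(set)
    for subject_id, code in events:
        subjects_per_code[code].add(subject_id)
    kept_codes = {
        code
        for code, subjects in subjects_per_code.items()
        if len(subjects) >= min_subjects_per_code
    }
    events = [(s, c) for s, c in events if c in kept_codes]

    # Count the remaining rows per subject (each row is treated as one event here).
    events_per_subject = defaultdict(int)
    for subject_id, _ in events:
        events_per_subject[subject_id] += 1
    kept_subjects = {
        s for s, n in events_per_subject.items() if n >= min_events_per_subject
    }
    return [(s, c) for s, c in events if s in kept_subjects]


toy_events = [(1, "A"), (1, "B"), (2, "A"), (2, "A"), (3, "C")]
filtered = filter_events(toy_events, min_subjects_per_code=2, min_events_per_subject=2)
```

With thresholds this strict on such a small dataset, most rows are dropped, which is the situation the demo mode above is meant to avoid.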
2. Pre-train the model
You can pre-train the model using the MEICAR_pretrain command. To use this, let us assume you have a new
directory to store the pretrained model artifacts called $PRETRAINED_MODEL_DIR. Then, you can run:
MEICAR_pretrain datamodule.config.tensorized_cohort_dir="$FINAL_DATA_DIR" \
output_dir="$PRETRAINED_MODEL_DIR" \
datamodule.batch_size=32
to train the model for 10 epochs.
This uses a Hydra configuration system, with the root config located in the
_pretrain.yaml file. You can override any of the nested
configuration parameters on the command line (as shown above with datamodule.config.tensorized_cohort_dir),
though you will more likely materialize an experimental configuration file to disk in YAML form and override
the config path and name directly in the normal Hydra manner.
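As a rough mental model of what a dotted override such as datamodule.batch_size=32 does, the sketch below applies key=value strings to a nested dictionary. This only illustrates the addressing scheme and is not how Hydra is implemented: the helper apply_overrides is hypothetical, and unlike Hydra it keeps every value as a string and does no validation.

```python
def apply_overrides(config, overrides):
    """Apply 'a.b.c=value' strings to a nested dict (illustrative only)."""
    for item in overrides:
        dotted_key, value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = config
        # Walk (or create) the nested groups named by the dotted path.
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value  # values stay strings in this toy version
    return config


config = apply_overrides(
    {"datamodule": {"batch_size": "16"}},
    ["datamodule.batch_size=32", "output_dir=/tmp/pretrained"],
)
```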
[!WARNING] Tests here only validate that the model runs without errors and (in demo mode) runs without producing NaNs or invalid values. It has not yet been assessed to ensure it runs to convergence, etc.
3. Zero-shot Inference
Zero-shot inference consists of two steps:
- Given a task cohort and a pre-trained model, for each sample in the task cohort, generate future trajectories from those inputs forward with the pre-trained model and save them to disk in a pseudo-MEDS format.
- Resolve these generated trajectories into concrete, probabilistic predictions for the task cohort.
3.1 Generate Trajectories for a task spec.
You can directly generate trajectories using the MEICAR_generate_trajectories command. This requires a few
more configuration parameters than the pre-training step, so let's go through those:
- You need to specify the task labels directory in the datamodule.config.task_labels_dir parameter.
- You need to specify the model initialization directory in the model_initialization_dir parameter. This is the output directory of the pre-train step.
- You need to specify how you want to trade-off between allowed input context size and the maximum possible
generated trajectory length. The former allows you to use more of the patient's record, but the latter
controls how far into the future you can predict. This can be configured with one of three parameters in
the seq_lens part of the config. If you set:
  - seq_lens.generation_context_size, that will be the maximum length of the input context, and the remaining length of the pretrained model's maximum sequence length will be used for generation.
  - seq_lens.max_generated_trajectory_len, that will be the maximum length of the generated trajectory, and the remaining length of the pretrained model's maximum sequence length will be used for the input.
  - seq_lens.frac_seq_len_as_context, that will be the fraction of the pretrained model's maximum sequence length that will be used for the input context, and the remaining length will be used for generation. This is set by default to 0.25, which means that 25% of the maximum sequence length will be used for the input context, and 75% will be used for generation. If you wish to use another mode on the command line, be sure to set this to null to disable it.
- Lastly, you need to specify how many trajectories per task sample you wish to generate, and for which
splits you wish to generate samples. You can do this via the inference.generate_for_splits and inference.N_trajectories_per_task_sample parameters. The former is a list of splits to generate and the latter is the number of trajectories to generate per task sample. The default is to generate 20 trajectories for each task sample in the tuning and held out splits. Each subject's N trajectories are interleaved into a single predict pass (see generation/repeated_dataset.py and issue #89), rather than run as N independent passes.
- If your desired trajectory length exceeds the model's per-chunk context, generation automatically
switches to a rolling sliding-window path (see Model._rolling_generate). Set rolling_generation.max_new_tokens to bound the total generated length. EOS and the step cap both terminate generation early; between chunks the context is re-primed with the tail of the running sequence.
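The arithmetic behind the three seq_lens modes can be sketched as follows. This is an illustration under stated assumptions rather than the package's code: the function name split_sequence_budget is hypothetical, the fractional mode is truncated with int() (the real rounding rule is not stated here), and the maximum sequence length of 1024 is an arbitrary example value.

```python
def split_sequence_budget(
    max_seq_len,
    generation_context_size=None,
    max_generated_trajectory_len=None,
    frac_seq_len_as_context=None,
):
    """Return (context_len, generation_len) for exactly one configured mode."""
    modes = (generation_context_size, max_generated_trajectory_len, frac_seq_len_as_context)
    if sum(m is not None for m in modes) != 1:
        raise ValueError("Set exactly one of the three seq_lens parameters.")

    if generation_context_size is not None:
        context_len = generation_context_size
    elif max_generated_trajectory_len is not None:
        context_len = max_seq_len - max_generated_trajectory_len
    else:
        # Assumption: truncate toward zero; the real rounding rule may differ.
        context_len = int(max_seq_len * frac_seq_len_as_context)

    if not 0 < context_len < max_seq_len:
        raise ValueError("Both the context and the generation budget must be positive.")
    return context_len, max_seq_len - context_len


# Default-style split: 25% context, 75% generation.
default_split = split_sequence_budget(1024, frac_seq_len_as_context=0.25)
```

In every mode the two budgets sum to the model's maximum sequence length, which is why enlarging the input context necessarily shortens the longest trajectory a single chunk can generate.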
After these are set, you can run the following command to generate trajectories for a task cohort:
MEICAR_generate_trajectories \
output_dir="$GENERATED_TRAJECTORIES_DIR" \
model_initialization_dir="$PRETRAINED_MODEL_DIR" \
datamodule.config.tensorized_cohort_dir="$FINAL_DATA_DIR" \
datamodule.config.task_labels_dir="$TASK_ROOT_DIR/$TASK_NAME" \
datamodule.batch_size=32
This will generate trajectories for the task cohort and save them in the format:
$GENERATED_TRAJECTORIES_DIR/$SPLIT/$SAMPLE.parquet.
See the documentation for format_trajectories for more
details on the format of the generated trajectories.
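Given the $GENERATED_TRAJECTORIES_DIR/$SPLIT/$SAMPLE.parquet layout described above, a small helper can index what was written before you load anything. This is a sketch using only the standard library; the helper name index_trajectory_files is hypothetical, and it inspects file names only, not file contents.

```python
from collections import defaultdict
from pathlib import Path


def index_trajectory_files(root):
    """Map each split directory name to the sorted sample file stems under it."""
    files_by_split = defaultdict(list)
    # Match exactly one directory level (the split) followed by a parquet file.
    for path in sorted(Path(root).glob("*/*.parquet")):
        files_by_split[path.parent.name].append(path.stem)
    return dict(files_by_split)
```

Calling index_trajectory_files on your generated trajectories directory returns a dictionary keyed by split, with one entry per generated sample file in that split.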
[!WARNING] The tests here only validate that this runs without errors and produces trajectory files that are valid, non-identical across different samples, and containing the right subjects. It has not yet been assessed to ensure full correctness.
[!NOTE] The generated trajectories from this model are saved in the schema defined in the
MEDS_trajectory_evaluation.schema.GeneratedTrajectorySchema format, and can be used with that package's evaluation tools.
3.2 Resolve Trajectories into Predictions.
Not yet implemented.
Documentation
Configuration and Controlling Model Structure
This model is configured via Hydra and PyTorch Lightning. The configuration structure of this repository is as follows:
>>> print_directory("./src/MEDS_EIC_AR/configs", config=PrintConfig(file_extension=".yaml"))
├── _demo_generate_trajectories.yaml
├── _demo_pretrain.yaml
├── _generate_trajectories.yaml
├── _pretrain.yaml
├── datamodule
│   ├── default.yaml
│   ├── generate_trajectories.yaml
│   └── pretrain.yaml
├── inference
│   ├── default.yaml
│   └── demo.yaml
├── lightning_module
│   ├── LR_scheduler
│   │   └── get_cosine_schedule_with_warmup.yaml
│   ├── default.yaml
│   ├── demo.yaml
│   ├── large.yaml
│   ├── medium.yaml
│   ├── metrics
│   │   └── default.yaml
│   ├── micro.yaml
│   ├── model
│   │   ├── default.yaml
│   │   ├── demo.yaml
│   │   ├── large.yaml
│   │   ├── medium.yaml
│   │   ├── micro.yaml
│   │   └── small.yaml
│   ├── optimizer
│   │   └── adamw.yaml
│   └── small.yaml
└── trainer
    ├── callbacks
    │   ├── default.yaml
    │   ├── early_stopping.yaml
    │   ├── generation.yaml
    │   ├── generation_speed_logger.yaml
    │   ├── learning_rate_monitor.yaml
    │   └── model_checkpoint.yaml
    ├── default.yaml
    ├── demo.yaml
    ├── demo_generate.yaml
    ├── generate.yaml
    └── logger
        ├── csv.yaml
        ├── mlflow.yaml
        └── wandb.yaml
Logging with wandb
You can activate the wandb logger by overriding the trainer logger to wandb:
MEICAR_pretrain trainer.logger=wandb
The configuration file configs/trainer/logger/wandb.yaml
exposes a tags field. Hydra makes the selected configuration groups available
via hydra.runtime.choices. These can be referenced to automatically tag the
run. For example:
tags:
- ${hydra:runtime.choices.lightning_module/model}
This automatically tags each run with the selected model size (e.g. small,
medium, large). Hydra currently cannot append to a list with a default
value. To add your own tags you must override the list and include the default
tag yourself. Wrap the override in single quotes so that the shell passes the
${...} interpolation through to Hydra instead of trying to expand it:
MEICAR_pretrain trainer.logger=wandb \
trainer.logger.tags='[${hydra:runtime.choices.lightning_module/model},experiment-1]'
This results in the tags [model_size, "experiment-1"] being sent to wandb.
Inference Backend
Generation runs through a pluggable backend abstraction (GenerationBackend protocol in
src/MEDS_EIC_AR/model/backends/). The default is HFBackend, which
wraps LlamaForCausalLM.generate in-process. A non-HF backend (e.g. SGLang, in-flight at #117) can drop
in behind the same interface without touching the rolling-generation loop or the trajectory-format
pipeline. See issue #88 for motivation and the roadmap.
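To illustrate why a backend abstraction lets the rolling loop stay unchanged, here is a toy sketch of a protocol-typed backend plugged into a sliding-window loop. Everything in it is hypothetical: ToyBackend is not the package's GenerationBackend, its generate signature is invented for this example, and rolling_generate is a simplified stand-in rather than Model._rolling_generate.

```python
from typing import Protocol


class ToyBackend(Protocol):
    """Hypothetical interface: return up to max_new_tokens tokens for a context."""

    def generate(self, context: list[int], max_new_tokens: int) -> list[int]: ...


class CountingBackend:
    """Toy stand-in for a model: continues counting from the last context token."""

    def generate(self, context, max_new_tokens):
        start = context[-1] + 1
        return list(range(start, start + max_new_tokens))


def rolling_generate(backend, prompt, max_new_tokens, window_size, chunk_size, eos_token=None):
    """Generate in chunks, re-priming each chunk with the tail of the running sequence."""
    if chunk_size >= window_size:
        raise ValueError("chunk_size must be smaller than window_size.")
    sequence = list(prompt)
    n_generated = 0
    while n_generated < max_new_tokens:
        n_step = min(chunk_size, max_new_tokens - n_generated)
        # Keep only as much tail as fits in the window alongside the new chunk.
        context = sequence[-(window_size - n_step):]
        for token in backend.generate(context, n_step):
            sequence.append(token)
            n_generated += 1
            if token == eos_token:  # stop early on end-of-sequence
                return sequence[len(prompt):]
    return sequence[len(prompt):]


generated = rolling_generate(CountingBackend(), [1, 2, 3], 5, window_size=4, chunk_size=2)
```

Because the loop only calls backend.generate, swapping CountingBackend for any other object with the same method requires no change to rolling_generate, which is the property the abstraction is meant to provide.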
Output Files
The output files of the pre-training step are stored in the directory specified by the output_dir parameter
and take the following structure:
>>> print_directory(pretrained_model)
├── .logs
│   ├── .hydra
│   │   ├── config.yaml
│   │   ├── hydra.yaml
│   │   └── overrides.yaml
│   └── __main__.log
├── best_model.ckpt
├── checkpoints
│   ├── epoch=0-step=1.ckpt
│   ├── epoch=0-step=2.ckpt
│   ├── epoch=1-step=3.ckpt
│   ├── epoch=1-step=4.ckpt
│   └── last.ckpt
├── config.yaml
├── environment.txt
├── loggers
│   └── csv
│       └── version_0
│           ├── hparams.yaml
│           └── metrics.csv
└── resolved_config.yaml
The files worth calling out:
- config.yaml -- the Hydra config as-resolved at the start of the run. Used by resume to verify no load-bearing param has changed.
- resolved_config.yaml -- the same config with Hydra interpolations resolved. Useful for comparing runs and for downstream tooling that cannot resolve Hydra syntax.
- environment.txt -- a pip-freeze-style snapshot of the Python environment at training time (Python version, platform, and every installed distribution and version), written via save_environment_snapshot. Only written on initial run creation, not on resume. Useful for reproducing a run and for tracking down environment drift between training and inference.
- best_model.ckpt -- a copy of the best checkpoint according to the model-checkpoint callback (written via shutil.copyfile, not a symlink, so the run directory is self-contained for rsync / archive).
- checkpoints/ -- Lightning's per-step / per-epoch checkpoints plus last.ckpt.
- .logs/.hydra/ -- Hydra's run metadata.
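Before pointing model_initialization_dir at a run directory, it can be handy to confirm that the artifacts listed above are present. The sketch below does this with the standard library only; the helper name missing_artifacts and the particular set of files it checks are choices made for this example, not a check the package itself is known to perform.

```python
from pathlib import Path

# Files named in the output listing above; adjust to taste.
EXPECTED_ARTIFACTS = (
    "config.yaml",
    "resolved_config.yaml",
    "environment.txt",
    "best_model.ckpt",
    "checkpoints/last.ckpt",
)


def missing_artifacts(run_dir):
    """Return the expected artifact paths that do not exist under run_dir."""
    run_dir = Path(run_dir)
    return [name for name in EXPECTED_ARTIFACTS if not (run_dir / name).exists()]
```

An empty list means every listed artifact was found; any names returned point to a run that was interrupted or copied incompletely.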