An efficient, flexible PyTorch dataset class for MEDS data.
MEDS TorchData: A PyTorch Dataset Class for MEDS Datasets
🚀 Quick Start
Step 1: Install
pip install meds-torch-data
Step 2: Data Tensorization
[!WARNING] If your dataset is not sharded by split, you need to run a `reshard_to_split` stage first! You can enable this by adding the `do_reshard=True` argument to the command below.
If your input MEDS dataset lives in $MEDS_ROOT and you want to store your pre-processed files in
$PYD_ROOT, you run:
MTD_preprocess MEDS_dataset_dir="$MEDS_ROOT" output_dir="$PYD_ROOT"
Step 3: Use the dataset
To use a dataset, you need to (1) define your configuration object and (2) create the dataset object. The only
required configuration parameters are `tensorized_cohort_dir`, which points to the root directory containing
the pre-processed data on disk (`$PYD_ROOT` in the above example), and `max_seq_len`, which is the maximum
sequence length you want to use for your model. Here's an example:
import os
from meds_torchdata import MEDSPytorchDataset, MEDSTorchDataConfig
cfg = MEDSTorchDataConfig(tensorized_cohort_dir=os.environ["PYD_ROOT"], max_seq_len=512)
pyd = MEDSPytorchDataset(cfg, split="train")
If you want to use a specific binary classification task, you can add the task_labels_dir parameter to the
configuration object. This should point to a directory containing the sharded MEDS label format parquet files
for the labels. The sharding scheme is arbitrary and will not be reflected in the dataset.
That's it!
[!NOTE] Only binary classification tasks are supported at this time. If you need multi-class classification or other kinds of tasks, please file a GitHub issue.
📚 Documentation
Design Principles
A good PyTorch dataset class should:
- Be easy to use.
- Have a minimal, constant resource footprint (memory, CPU, start-up time) during model training and inference, regardless of the overall dataset size.
- Perform as much work as possible in static, reusable dataset pre-processing, rather than upon construction or in the `__getitem__` method.
- Induce effectively negligible computational overhead in the `__getitem__` method relative to model training.
- Be easily configurable, with a simple, consistent API, and cover the most common use-cases.
- Encourage efficient use of GPU resources in the resulting batches.
- Be comprehensively documented, tested, and benchmarked for performance implications so users can use it reliably and effectively.
To achieve this, MEDS TorchData leverages the following design principles:
- Lazy Loading: Data is loaded only when needed, and only the data needed for the current batch is loaded.
- Efficient Loading: Data is loaded efficiently leveraging the HuggingFace Safetensors library for raw IO through the nested, ragged interface encoded in the Nested Ragged Tensors library.
- Configurable, Transparent Pre-processing: Mandatory data pre-processing prior to effective use in this library is managed through a simple MEDS-Transforms pipeline which can be run on any MEDS dataset, after any model-specific pre-processing, via a transparent configuration file.
- Continuous Integration: The library is continuously tested and benchmarked for performance implications, and the results are available to users.
Examples and Detailed Usage
To see how this works, let's look at some examples. These examples will be powered by some synthetic data defined as "fixtures" in this package's pytest stack; namely, we'll use the following fixtures:
- `simple_static_MEDS`: a Path containing a simple MEDS dataset.
- `simple_static_MEDS_dataset_with_task`: a Path containing a simple MEDS dataset with a boolean-valued task defined. The core data is the same as in `simple_static_MEDS`, but this dataset also has a task defined.
- `tensorized_MEDS_dataset`: a Path containing the tensorized and schema files for the `simple_static_MEDS` dataset.
- `tensorized_MEDS_dataset_with_task`: a tuple containing:
    - a Path containing the tensorized and schema files for the `simple_static_MEDS_dataset_with_task` dataset,
    - a Path pointing to the root task directory for the dataset, and
    - the specific task name for the dataset. Task label files are stored in a subdirectory of the root task directory with this name.
You can find these in either the conftest.py file for this repository or the
meds_testing_helpers package, which
this package leverages for testing.
Synthetic Data
To start, let's take a look at this synthetic data. It is sharded by split, and we'll look at the train split first, which has two shards containing four subjects in total (we convert to polars just for prettier printing):
>>> import polars as pl
>>> from meds_testing_helpers.dataset import MEDSDataset
>>> D = MEDSDataset(root_dir=simple_static_MEDS)
>>> train_0 = pl.from_arrow(D.data_shards["train/0"])
>>> train_0
shape: (30, 4)
┌────────────┬─────────────────────┬────────────────────┬───────────────┐
│ subject_id ┆ time                ┆ code               ┆ numeric_value │
│ ---        ┆ ---                 ┆ ---                ┆ ---           │
│ i64        ┆ datetime[μs]        ┆ str                ┆ f32           │
╞════════════╪═════════════════════╪════════════════════╪═══════════════╡
│ 239684     ┆ null                ┆ EYE_COLOR//BROWN   ┆ null          │
│ 239684     ┆ null                ┆ HEIGHT             ┆ 175.271118    │
│ 239684     ┆ 1980-12-28 00:00:00 ┆ DOB                ┆ null          │
│ 239684     ┆ 2010-05-11 17:41:51 ┆ ADMISSION//CARDIAC ┆ null          │
│ 239684     ┆ 2010-05-11 17:41:51 ┆ HR                 ┆ 102.599998    │
│ …          ┆ …                   ┆ …                  ┆ …             │
│ 1195293    ┆ 2010-06-20 20:24:44 ┆ HR                 ┆ 107.699997    │
│ 1195293    ┆ 2010-06-20 20:24:44 ┆ TEMP               ┆ 100.0         │
│ 1195293    ┆ 2010-06-20 20:41:33 ┆ HR                 ┆ 107.5         │
│ 1195293    ┆ 2010-06-20 20:41:33 ┆ TEMP               ┆ 100.400002    │
│ 1195293    ┆ 2010-06-20 20:50:04 ┆ DISCHARGE          ┆ null          │
└────────────┴─────────────────────┴────────────────────┴───────────────┘
>>> train_1 = pl.from_arrow(D.data_shards["train/1"])
>>> train_1
shape: (14, 4)
┌────────────┬─────────────────────┬───────────────────────┬───────────────┐
│ subject_id ┆ time                ┆ code                  ┆ numeric_value │
│ ---        ┆ ---                 ┆ ---                   ┆ ---           │
│ i64        ┆ datetime[μs]        ┆ str                   ┆ f32           │
╞════════════╪═════════════════════╪═══════════════════════╪═══════════════╡
│ 68729      ┆ null                ┆ EYE_COLOR//HAZEL      ┆ null          │
│ 68729      ┆ null                ┆ HEIGHT                ┆ 160.395309    │
│ 68729      ┆ 1978-03-09 00:00:00 ┆ DOB                   ┆ null          │
│ 68729      ┆ 2010-05-26 02:30:56 ┆ ADMISSION//PULMONARY  ┆ null          │
│ 68729      ┆ 2010-05-26 02:30:56 ┆ HR                    ┆ 86.0          │
│ …          ┆ …                   ┆ …                     ┆ …             │
│ 814703     ┆ 1976-03-28 00:00:00 ┆ DOB                   ┆ null          │
│ 814703     ┆ 2010-02-05 05:55:39 ┆ ADMISSION//ORTHOPEDIC ┆ null          │
│ 814703     ┆ 2010-02-05 05:55:39 ┆ HR                    ┆ 170.199997    │
│ 814703     ┆ 2010-02-05 05:55:39 ┆ TEMP                  ┆ 100.099998    │
│ 814703     ┆ 2010-02-05 07:02:30 ┆ DISCHARGE             ┆ null          │
└────────────┴─────────────────────┴───────────────────────┴───────────────┘
>>> sorted(set(train_0["subject_id"].unique()) | set(train_1["subject_id"].unique()))
[68729, 239684, 814703, 1195293]
MEDSTorchDataConfig Configuration Object
Full API documentation for the configuration object can be found here.
The configuration object contains two kinds of parameters: Data processing parameters and file paths. Data processing parameters include:
- `max_seq_len`: The maximum sequence length to use for the model.
- `seq_sampling_strategy`: The strategy to use when sampling sub-sequences to return for input sequences longer than `max_seq_len`.
- `static_inclusion_mode`: The mode to use when including static data in the output.
- `batch_mode`: Whether to return sequences at the measurement level (`"SM"`) or the event level (`"SEM"`). Note that here, we use "measurement" to refer to a single row (observation) in the raw MEDS data, and "event" to refer to all measurements taken at a single time-point.
- `include_window_last_observed_in_schema`: If `True`, include the timestamp of the last observation in each sampled window in the dataset's `schema_df` when an index dataframe is used and the sampling strategy is deterministic. This is useful for generative applications where the model needs to know the timestamp at the start of a generation window, for example.
Of these, seq_sampling_strategy and static_inclusion_mode are restricted, and must be of the
SubsequenceSamplingStrategy
and
StaticInclusionMode
StrEnums, respectively:
- `seq_sampling_strategy`: One of `["random", "to_end", "from_start"]` (defaults to `"random"`).
- `static_inclusion_mode`: One of `["include", "prepend", "omit"]` (defaults to `"include"`).
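These restricted options behave like standard Python string enums: valid values round-trip cleanly, and invalid values fail loudly at construction time. A minimal sketch of that behavior (the class below is an illustrative stand-in, not the library's actual implementation):

```python
from enum import Enum

class SubsequenceSamplingStrategy(str, Enum):
    """Illustrative stand-in for the library's StrEnum of sampling strategies."""
    RANDOM = "random"
    TO_END = "to_end"
    FROM_START = "from_start"

# Valid values round-trip cleanly...
assert SubsequenceSamplingStrategy("to_end") is SubsequenceSamplingStrategy.TO_END

# ...while invalid values raise a ValueError at construction time.
try:
    SubsequenceSamplingStrategy("middle")
except ValueError as e:
    print(e)
```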
File path parameters include:
- `tensorized_cohort_dir`: The directory containing the tensorized data.
- `task_labels_dir`: The directory containing the task label files.
It also provides a convenient property to get the vocab size for the dataset, given by the vocab indices in the tensorized metadata. Let's start by building a configuration object for this data and inspect some of its file-path related properties and helpers:
>>> from meds_torchdata import MEDSTorchDataConfig
>>> cfg = MEDSTorchDataConfig(tensorized_MEDS_dataset, max_seq_len=5)
>>> cfg.tensorized_cohort_dir
PosixPath('/tmp/tmp...')
>>> cfg.schema_dir
PosixPath('/tmp/tmp.../tokenization/schemas')
>>> print(sorted(list(cfg.schema_fps)))
[('held_out/0', PosixPath('/tmp/tmp.../tokenization/schemas/held_out/0.parquet')),
('train/0', PosixPath('/tmp/tmp.../tokenization/schemas/train/0.parquet')),
('train/1', PosixPath('/tmp/tmp.../tokenization/schemas/train/1.parquet')),
('tuning/0', PosixPath('/tmp/tmp.../tokenization/schemas/tuning/0.parquet'))]
>>> print(cfg.task_labels_dir)
None
>>> print(cfg.task_labels_fps)
None
>>> print(cfg.vocab_size)
12
If we specify a `task_labels_dir` parameter, the config operates in task-specific mode. This allows us to use
the task-specific helpers, but it also mandates that we set `seq_sampling_strategy` to `"to_end"`, as you shouldn't
try to predict a downstream task without leveraging the most recent data.
>>> cohort_dir, tasks_dir, task_name = tensorized_MEDS_dataset_with_task
>>> cfg = MEDSTorchDataConfig(
... cohort_dir, max_seq_len=5, task_labels_dir=(tasks_dir / task_name)
... )
Traceback (most recent call last):
...
ValueError: Not sampling data till the end of the sequence when predicting for a specific task is not
permitted! This is because there is no use-case we know of where you would want to do this. If you disagree,
please let us know via a GitHub issue.
>>> cfg = MEDSTorchDataConfig(
... cohort_dir, max_seq_len=5, task_labels_dir=(tasks_dir / task_name), seq_sampling_strategy="to_end"
... )
>>> cfg.task_labels_dir
PosixPath('/tmp/tmp.../task_labels/boolean_value_task')
>>> print(list(cfg.task_labels_fps))
[PosixPath('/tmp/tmp.../task_labels/boolean_value_task/labels_A.parquet.parquet'),
PosixPath('/tmp/tmp.../task_labels/boolean_value_task/labels_B.parquet.parquet')]
Based on the `seq_sampling_strategy`, `batch_mode`, and `max_seq_len` parameters, the configuration
object also provides the `process_dynamic_data` helper function to slice a subject's dynamic data
appropriately. This function is used internally, and you will not need to call it yourself.
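The windowing logic can be understood with a small sketch (a hypothetical helper, not the library's actual `process_dynamic_data` signature): given a subject with `seq_len` events and a window of `max_seq_len`, each strategy picks the slice as follows.

```python
import random

def sample_window(seq_len, max_seq_len, strategy, rng=None):
    """Return the (start, end) slice of a subject's sequence for a strategy.

    Hypothetical sketch of the windowing logic; the real helper also handles
    task end-points and static-data inclusion.
    """
    if seq_len <= max_seq_len:
        return 0, seq_len
    if strategy == "from_start":
        return 0, max_seq_len
    if strategy == "to_end":
        return seq_len - max_seq_len, seq_len
    # "random": any window of length max_seq_len, chosen uniformly.
    start = (rng or random).randrange(seq_len - max_seq_len + 1)
    return start, start + max_seq_len

print(sample_window(9, 5, "to_end"))      # (4, 9)
print(sample_window(9, 5, "from_start"))  # (0, 5)
```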
MEDSPytorchDataset Dataset Class
Full API documentation for the dataset class can be found here.
Now let's build a dataset object from the synthetic data.
Dataset "Schema"
When we build a PyTorch dataset from it for training, with no task specified, the length will be four, as it
will correspond to each of the four subjects in the train split. The `index` variable contains, for each
dataset element, the subject ID and the end of the allowed reading region for that subject. We can also see
this in dataframe format via the `schema_df`:
>>> from meds_torchdata import MEDSPytorchDataset
>>> cfg = MEDSTorchDataConfig(tensorized_cohort_dir=tensorized_MEDS_dataset, max_seq_len=5)
>>> pyd = MEDSPytorchDataset(cfg, split="train")
>>> len(pyd)
4
>>> pyd.index
[(239684, 6), (1195293, 8), (68729, 3), (814703, 3)]
>>> pyd.schema_df
shape: (4, 2)
┌────────────┬─────────────────┐
│ subject_id ┆ end_event_index │
│ ---        ┆ ---             │
│ i64        ┆ u32             │
╞════════════╪═════════════════╡
│ 239684     ┆ 6               │
│ 1195293    ┆ 8               │
│ 68729      ┆ 3               │
│ 814703     ┆ 3               │
└────────────┴─────────────────┘
Note the index is in terms of event indices, not measurement indices -- meaning it is the index of the
unique timestamp corresponding to the start and end of each subject's data, not of individual measurements. We
can validate that against the raw data. To do so, we'll define two small helpers: `get_event_indices`, which
groups by the `subject_id` and `time` columns and calculates the event index for each subject, and
`get_event_bounds`, which shows us the min and max such index, per-subject.
>>> def get_event_indices(df: pl.DataFrame) -> pl.DataFrame:
... return (
... df
... .group_by("subject_id", "time", maintain_order=True).agg(pl.len().alias("n_measurements"))
... .with_row_index()
... .select(
... "subject_id", "time",
... (pl.col("index") - pl.col("index").min().over("subject_id")).alias("event_idx"),
... "n_measurements",
... )
... )
>>> def get_event_bounds(df: pl.DataFrame) -> pl.DataFrame:
... return (
... get_event_indices(df)
... .with_columns(
... pl.col("event_idx").max().over("subject_id").alias("max_event_idx")
... )
... .filter((pl.col("event_idx") == 0) | (pl.col("event_idx") == pl.col("max_event_idx")))
... .select("subject_id", "event_idx", "time")
... )
>>> get_event_bounds(train_0)
shape: (4, 3)
┌────────────┬───────────┬─────────────────────┐
│ subject_id ┆ event_idx ┆ time                │
│ ---        ┆ ---       ┆ ---                 │
│ i64        ┆ u32       ┆ datetime[μs]        │
╞════════════╪═══════════╪═════════════════════╡
│ 239684     ┆ 0         ┆ null                │
│ 239684     ┆ 6         ┆ 2010-05-11 19:27:19 │
│ 1195293    ┆ 0         ┆ null                │
│ 1195293    ┆ 8         ┆ 2010-06-20 20:50:04 │
└────────────┴───────────┴─────────────────────┘
>>> get_event_bounds(train_1)
shape: (4, 3)
┌────────────┬───────────┬─────────────────────┐
│ subject_id ┆ event_idx ┆ time                │
│ ---        ┆ ---       ┆ ---                 │
│ i64        ┆ u32       ┆ datetime[μs]        │
╞════════════╪═══════════╪═════════════════════╡
│ 68729      ┆ 0         ┆ null                │
│ 68729      ┆ 3         ┆ 2010-05-26 04:51:52 │
│ 814703     ┆ 0         ┆ null                │
│ 814703     ┆ 3         ┆ 2010-02-05 07:02:30 │
└────────────┴───────────┴─────────────────────┘
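The same event-index bookkeeping can be sketched without polars. Assuming rows are sorted so each subject's measurements are contiguous and time-ordered (a plain-Python illustration, not a library API):

```python
from itertools import groupby

def event_bounds(rows):
    """Map subject_id -> (first_event_index, last_event_index).

    `rows` is a list of (subject_id, time) pairs; consecutive rows that share
    a timestamp (including the null-time static rows) form a single event.
    """
    out = {}
    for sid, grp in groupby(rows, key=lambda r: r[0]):
        times = [t for _, t in grp]
        n_events = 1 + sum(1 for a, b in zip(times, times[1:]) if a != b)
        out[sid] = (0, n_events - 1)
    return out

# Subject 68729's rows: static (null-time) rows, DOB, an admission event with
# its measurements, and a final event -- 4 events, so bounds are (0, 3).
rows = [
    (68729, None), (68729, None),
    (68729, "1978-03-09 00:00:00"),
    (68729, "2010-05-26 02:30:56"), (68729, "2010-05-26 02:30:56"),
    (68729, "2010-05-26 04:51:52"),
]
print(event_bounds(rows))  # {68729: (0, 3)}
```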
The schema changes to reflect the different split if we change the split:
>>> pyd_tuning = MEDSPytorchDataset(cfg, split="tuning")
>>> pyd_tuning.schema_df
shape: (1, 2)
┌────────────┬─────────────────┐
│ subject_id ┆ end_event_index │
│ ---        ┆ ---             │
│ i64        ┆ u32             │
╞════════════╪═════════════════╡
│ 754281     ┆ 3               │
└────────────┴─────────────────┘
>>> pyd_held_out = MEDSPytorchDataset(cfg, split="held_out")
>>> pyd_held_out.schema_df
shape: (1, 2)
┌────────────┬─────────────────┐
│ subject_id ┆ end_event_index │
│ ---        ┆ ---             │
│ i64        ┆ u32             │
╞════════════╪═════════════════╡
│ 1500733    ┆ 5               │
└────────────┴─────────────────┘
If you use a non-existent split or have something misconfigured, you'll get an error upon Dataset creation:
>>> pyd_bad = MEDSPytorchDataset(cfg, split="bad_split")
Traceback (most recent call last):
...
FileNotFoundError: No schema files found in /tmp/.../tokenization/schemas! If your data is not sharded by
split, this error may occur because this codebase does not handle non-split sharded data. See Issue #79 for
tracking this issue.
We can also inspect the schema for a dataset built with downstream task labels:
>>> cohort_dir, tasks_dir, task_name = tensorized_MEDS_dataset_with_task
>>> cfg_with_task = MEDSTorchDataConfig(
... cohort_dir, max_seq_len=5, task_labels_dir=(tasks_dir / task_name), seq_sampling_strategy="to_end"
... )
>>> pyd_with_task = MEDSPytorchDataset(cfg_with_task, split="train")
>>> pyd_with_task.schema_df
shape: (13, 4)
┌────────────┬─────────────────┬─────────────────────┬───────────────┐
│ subject_id ┆ end_event_index ┆ prediction_time     ┆ boolean_value │
│ ---        ┆ ---             ┆ ---                 ┆ ---           │
│ i64        ┆ u32             ┆ datetime[μs]        ┆ bool          │
╞════════════╪═════════════════╪═════════════════════╪═══════════════╡
│ 239684     ┆ 3               ┆ 2010-05-11 18:00:00 ┆ false         │
│ 239684     ┆ 4               ┆ 2010-05-11 18:30:00 ┆ true          │
│ 239684     ┆ 5               ┆ 2010-05-11 19:00:00 ┆ true          │
│ 1195293    ┆ 3               ┆ 2010-06-20 19:30:00 ┆ false         │
│ 1195293    ┆ 4               ┆ 2010-06-20 20:00:00 ┆ true          │
│ …          ┆ …               ┆ …                   ┆ …             │
│ 68729      ┆ 2               ┆ 2010-05-26 04:00:00 ┆ true          │
│ 68729      ┆ 2               ┆ 2010-05-26 04:30:00 ┆ true          │
│ 814703     ┆ 2               ┆ 2010-02-05 06:00:00 ┆ false         │
│ 814703     ┆ 2               ┆ 2010-02-05 06:30:00 ┆ true          │
│ 814703     ┆ 2               ┆ 2010-02-05 07:00:00 ┆ true          │
└────────────┴─────────────────┴─────────────────────┴───────────────┘
When we have a task or index dataframe (an index is just a task without a label), we can also ask that the
last observed time in each input window be included in the schema, via the
`include_window_last_observed_in_schema` parameter:
>>> cfg_with_end_time = MEDSTorchDataConfig(
... cohort_dir, max_seq_len=5, task_labels_dir=(tasks_dir / task_name), seq_sampling_strategy="to_end",
... include_window_last_observed_in_schema=True
... )
>>> pyd_with_end_time = MEDSPytorchDataset(cfg_with_end_time, split="train")
>>> pyd_with_end_time.schema_df
shape: (13, 5)
┌────────────┬─────────────────┬─────────────────────┬───────────────┬──────────────────────┐
│ subject_id ┆ end_event_index ┆ prediction_time     ┆ boolean_value ┆ window_last_observed │
│ ---        ┆ ---             ┆ ---                 ┆ ---           ┆ ---                  │
│ i64        ┆ u32             ┆ datetime[μs]        ┆ bool          ┆ datetime[μs]         │
╞════════════╪═════════════════╪═════════════════════╪═══════════════╪══════════════════════╡
│ 239684     ┆ 3               ┆ 2010-05-11 18:00:00 ┆ false         ┆ 2010-05-11 17:48:48  │
│ 239684     ┆ 4               ┆ 2010-05-11 18:30:00 ┆ true          ┆ 2010-05-11 18:25:35  │
│ 239684     ┆ 5               ┆ 2010-05-11 19:00:00 ┆ true          ┆ 2010-05-11 18:57:18  │
│ 1195293    ┆ 3               ┆ 2010-06-20 19:30:00 ┆ false         ┆ 2010-06-20 19:25:32  │
│ 1195293    ┆ 4               ┆ 2010-06-20 20:00:00 ┆ true          ┆ 2010-06-20 19:45:19  │
│ …          ┆ …               ┆ …                   ┆ …             ┆ …                    │
│ 68729      ┆ 2               ┆ 2010-05-26 04:00:00 ┆ true          ┆ 2010-05-26 02:30:56  │
│ 68729      ┆ 2               ┆ 2010-05-26 04:30:00 ┆ true          ┆ 2010-05-26 02:30:56  │
│ 814703     ┆ 2               ┆ 2010-02-05 06:00:00 ┆ false         ┆ 2010-02-05 05:55:39  │
│ 814703     ┆ 2               ┆ 2010-02-05 06:30:00 ┆ true          ┆ 2010-02-05 05:55:39  │
│ 814703     ┆ 2               ┆ 2010-02-05 07:00:00 ┆ true          ┆ 2010-02-05 05:55:39  │
└────────────┴─────────────────┴─────────────────────┴───────────────┴──────────────────────┘
Returned items
While the raw data stores codes as strings, they are naturally converted to integers when embedded in the PyTorch dataset. This happens during the aforementioned tensorization step. We can see how the codes are mapped to integers by looking at the output code metadata of that step:
>>> code_metadata = pl.read_parquet(tensorized_MEDS_dataset.joinpath("metadata/codes.parquet"))
>>> code_metadata.select("code", "code/vocab_index")
shape: (11, 2)
┌───────────────────────┬──────────────────┐
│ code                  ┆ code/vocab_index │
│ ---                   ┆ ---              │
│ str                   ┆ u8               │
╞═══════════════════════╪══════════════════╡
│ ADMISSION//CARDIAC    ┆ 1                │
│ ADMISSION//ORTHOPEDIC ┆ 2                │
│ ADMISSION//PULMONARY  ┆ 3                │
│ DISCHARGE             ┆ 4                │
│ DOB                   ┆ 5                │
│ …                     ┆ …                │
│ EYE_COLOR//BROWN      ┆ 7                │
│ EYE_COLOR//HAZEL      ┆ 8                │
│ HEIGHT                ┆ 9                │
│ HR                    ┆ 10               │
│ TEMP                  ┆ 11               │
└───────────────────────┴──────────────────┘
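Notice the indices start at 1: index 0 is left for padding, which is presumably why the `vocab_size` reported above is 12 for these 11 codes. A minimal sketch of how such a mapping could be built, using a reduced, illustrative code set:

```python
# A reduced, illustrative subset of the codes above.
codes = ["HR", "TEMP", "ADMISSION//CARDIAC", "DISCHARGE"]

# Codes are enumerated alphabetically starting at 1; 0 is reserved for padding.
vocab_index = {code: i + 1 for i, code in enumerate(sorted(codes))}
vocab_size = len(vocab_index) + 1  # +1 accounts for the padding index

print(vocab_index)  # {'ADMISSION//CARDIAC': 1, 'DISCHARGE': 2, 'HR': 3, 'TEMP': 4}
print(vocab_size)   # 5
```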
We can see these vocab indices being used if we look at some elements of the PyTorch dataset. Note that some
elements of the returned dictionaries are JointNestedRaggedTensorDict objects, so we'll define a helper that
uses a utility from the associated library to pretty-print our outputs. Note that we'll also reduce precision
in the numeric values to make the output more readable.
>>> from nested_ragged_tensors.ragged_numpy import pprint_dense
>>> def print_element(el: dict):
... for k, v in el.items():
... print(f"{k} ({type(v).__name__}):")
... if k == "dynamic":
... pprint_dense(v.to_dense())
... else:
... print(v)
>>> print_element(pyd[2])
static_code (list):
[8, 9]
static_numeric_value (list):
[nan, -0.5438239574432373]
dynamic (JointNestedRaggedTensorDict):
code
[ 5 3 10 11 4]
.
numeric_value
[ nan nan -1.4474752 -0.34049404 nan]
.
time_delta_days
[ nan 1.17661045e+04 0.00000000e+00 0.00000000e+00
9.78703722e-02]
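The `time_delta_days` values are the gaps, in days, between an entry and the previous timestamp (NaN where no prior timestamp exists, e.g. for the birth and static entries). A small illustrative sketch of that computation on plain datetimes:

```python
from datetime import datetime

# Timestamps mirroring subject 68729's admission and final events above.
times = [
    datetime(2010, 5, 26, 2, 30, 56),
    datetime(2010, 5, 26, 2, 30, 56),  # same event: zero delta
    datetime(2010, 5, 26, 4, 51, 52),
]
deltas = [float("nan")] + [
    (b - a).total_seconds() / 86400.0 for a, b in zip(times, times[1:])
]
print(deltas[1], round(deltas[2], 6))  # 0.0 and roughly 0.09787 days
```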
This example shows what the output looks like if we set the static data inclusion mode to "include". What if
we set it to "prepend" instead? To show this in a stable manner, we'll also use the seeded version of the
get-item function, `_seeded_getitem`:
>>> pyd.config.static_inclusion_mode = "prepend"
>>> print_element(pyd._seeded_getitem(2, seed=0))
n_static_seq_els (int):
2
dynamic (JointNestedRaggedTensorDict):
code
[ 8 9 3 10 11]
.
numeric_value
[ nan -0.54382396 nan -1.4474752 -0.34049404]
.
time_delta_days
[ nan nan 11766.1045 0. 0. ]
>>> pyd.config.static_inclusion_mode = "include"
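A plain-Python sketch of what "prepend" does (a hypothetical helper, not the library's internal API): the static codes occupy the leading slots of the fixed-length window, and the already-sampled dynamic window fills the remainder.

```python
def prepend_static(static_codes, dynamic_codes, max_seq_len):
    """Reserve leading slots for static codes; the rest hold the tail of the
    (already sampled) dynamic window. Hypothetical illustration only."""
    n_dynamic = max_seq_len - len(static_codes)
    return static_codes + dynamic_codes[-n_dynamic:]

# Mirrors the example above: static [8, 9] + sampled dynamic window [3, 10, 11].
print(prepend_static([8, 9], [3, 10, 11], 5))  # [8, 9, 3, 10, 11]
```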
We can also look at what would be returned if we had included a task in the dataset:
>>> print_element(pyd_with_task[0])
static_code (list):
[7, 9]
static_numeric_value (list):
[nan, 1.5770268440246582]
dynamic (JointNestedRaggedTensorDict):
code
[ 1 10 11 10 11]
.
numeric_value
[ nan -0.5697369 -1.2714673 -0.4375474 -1.1680276]
.
time_delta_days
[1.0726737e+04 0.0000000e+00 0.0000000e+00 4.8263888e-03 0.0000000e+00]
boolean_value (bool):
False
We can see in this case that the `boolean_value` field is included in the output, capturing the task label.
The contents of `pyd[2]` are stable because index element 2, `(68729, 3)`, indicates this subject has a
sequence of only 3 events in the dataset, while our `max_seq_len` is set to 5.
>>> print_element(pyd[2])
static_code (list):
[8, 9]
static_numeric_value (list):
[nan, -0.5438239574432373]
dynamic (JointNestedRaggedTensorDict):
code
[ 5 3 10 11 4]
.
numeric_value
[ nan nan -1.4474752 -0.34049404 nan]
.
time_delta_days
[ nan 1.17661045e+04 0.00000000e+00 0.00000000e+00
9.78703722e-02]
If we sampled a different subject -- one with more than 5 events -- the output we'd get would depend on the
`config.seq_sampling_strategy` option and could be non-deterministic. By default, this is set to `"random"`, so
we'll get a random subset of length 5 each time. Here, so that this code is deterministic, we'll use
`_seeded_getitem`, an internal, seeded version of the `__getitem__` call.
>>> print_element(pyd._seeded_getitem(1, seed=0))
static_code (list):
[6, 9]
static_numeric_value (list):
[nan, 0.06802856922149658]
dynamic (JointNestedRaggedTensorDict):
code
[10 11 10 11 10]
.
numeric_value
[-0.04626633 0.69391906 -0.30007038 0.79735875 -0.31064537]
.
time_delta_days
[0.01888889 0. 0.0084838 0. 0.01167824]
>>> print_element(pyd._seeded_getitem(1, seed=1))
static_code (list):
[6, 9]
static_numeric_value (list):
[nan, 0.06802856922149658]
dynamic (JointNestedRaggedTensorDict):
code
[10 11 10 11 10]
.
numeric_value
[ 0.03833488 0.79735875 0.33972722 0.7456389 -0.04626633]
.
time_delta_days
[0.00115741 0. 0.01373843 0. 0.01888889]
Of course, if we set seq_sampling_strategy to something other than "random", this non-determinism would
disappear:
>>> cfg_from_start = MEDSTorchDataConfig(
... tensorized_cohort_dir=tensorized_MEDS_dataset, max_seq_len=5, seq_sampling_strategy="from_start"
... )
>>> pyd_from_start = MEDSPytorchDataset(cfg_from_start, split="train")
>>> print_element(pyd_from_start[1])
static_code (list):
[6, 9]
static_numeric_value (list):
[nan, 0.06802856922149658]
dynamic (JointNestedRaggedTensorDict):
code
[ 5 1 10 11 10]
.
numeric_value
[ nan nan -0.23133166 0.79735875 0.03833488]
.
time_delta_days
[ nan 1.1688809e+04 0.0000000e+00 0.0000000e+00 1.1574074e-03]
>>> print_element(pyd_from_start[1])
static_code (list):
[6, 9]
static_numeric_value (list):
[nan, 0.06802856922149658]
dynamic (JointNestedRaggedTensorDict):
code
[ 5 1 10 11 10]
.
numeric_value
[ nan nan -0.23133166 0.79735875 0.03833488]
.
time_delta_days
[ nan 1.1688809e+04 0.0000000e+00 0.0000000e+00 1.1574074e-03]
Batches, Collation, and Dataloaders
We can examine not just individual elements but full batches, which we can access with the appropriate
collate function via the built-in `get_dataloader` method. Here, we'll treat these outputs like
dictionaries, but they are actually dataclass objects with additional properties for accessing shapes and
validating data. See the API documentation on the batch class for more information.
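Conceptually, collation pads each element's sequences out to the longest sequence in the batch (with code 0 as padding) and records a validity mask; a stdlib-only sketch of that idea (the real collate function operates on JointNestedRaggedTensorDicts and returns a MEDSTorchBatch):

```python
def pad_collate(code_seqs, pad=0):
    """Pad ragged per-subject code sequences into a rectangular batch + mask."""
    max_len = max(len(s) for s in code_seqs)
    batch = [s + [pad] * (max_len - len(s)) for s in code_seqs]
    mask = [[i < len(s) for i in range(max_len)] for s in code_seqs]
    return batch, mask

batch, mask = pad_collate([[5, 3, 10, 11, 4], [5, 2, 10]])
print(batch)    # [[5, 3, 10, 11, 4], [5, 2, 10, 0, 0]]
print(mask[1])  # [True, True, True, False, False]
```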
>>> batches = [batch for batch in pyd.get_dataloader(batch_size=2)]
>>> print(batches[1])
MEDSTorchBatch:
โ Mode: Subject-Measurement (SM)
โ Static data? โ
โ Labels? โ
โ
โ Shape:
โ โ Batch size: 2
โ โ Sequence length: 5
โ โ
โ โ All dynamic data: (2, 5)
โ โ Static data: (2, 2)
โ
โ Data:
โ โ Dynamic:
โ โ โ time_delta_days (torch.float32):
โ โ โ โ [[0.00e+00, 1.18e+04, ..., 0.00e+00, 9.79e-02],
โ โ โ โ [0.00e+00, 1.24e+04, ..., 0.00e+00, 4.64e-02]]
โ โ โ code (torch.int64):
โ โ โ โ [[ 5, 3, ..., 11, 4],
โ โ โ โ [ 5, 2, ..., 11, 4]]
โ โ โ numeric_value (torch.float32):
โ โ โ โ [[ 0.00, 0.00, ..., -0.34, 0.00],
โ โ โ โ [ 0.00, 0.00, ..., 0.85, 0.00]]
โ โ โ numeric_value_mask (torch.bool):
โ โ โ โ [[False, False, ..., True, False],
โ โ โ โ [False, False, ..., True, False]]
โ โ
โ โ Static:
โ โ โ static_code (torch.int64):
โ โ โ โ [[8, 9],
โ โ โ โ [8, 9]]
โ โ โ static_numeric_value (torch.float32):
โ โ โ โ [[ 0.00, -0.54],
โ โ โ โ [ 0.00, -1.10]]
โ โ โ static_numeric_value_mask (torch.bool):
โ โ โ โ [[False, True],
โ โ โ โ [False, True]]
>>> print(next(iter(pyd_with_task.get_dataloader(batch_size=2))))
MEDSTorchBatch:
โ Mode: Subject-Measurement (SM)
โ Static data? โ
โ Labels? โ
โ
โ Shape:
โ โ Batch size: 2
โ โ Sequence length: 5
โ โ
โ โ All dynamic data: (2, 5)
โ โ Static data: (2, 2)
โ โ Labels: torch.Size([2])
โ
โ Data:
โ โ Dynamic:
โ โ โ time_delta_days (torch.float32):
โ โ โ โ [[1.07e+04, 0.00e+00, ..., 4.83e-03, 0.00e+00],
โ โ โ โ [0.00e+00, 4.83e-03, ..., 2.55e-02, 0.00e+00]]
โ โ โ code (torch.int64):
โ โ โ โ [[ 1, 10, ..., 10, 11],
โ โ โ โ [11, 10, ..., 10, 11]]
โ โ โ numeric_value (torch.float32):
โ โ โ โ [[ 0.00e+00, -5.70e-01, ..., -4.38e-01, -1.17e+00],
โ โ โ โ [-1.27e+00, -4.38e-01, ..., 1.32e-03, -1.37e+00]]
โ โ โ numeric_value_mask (torch.bool):
โ โ โ โ [[False, True, ..., True, True],
โ โ โ โ [ True, True, ..., True, True]]
โ โ
โ โ Static:
โ โ โ static_code (torch.int64):
โ โ โ โ [[7, 9],
โ โ โ โ [7, 9]]
โ โ โ static_numeric_value (torch.float32):
โ โ โ โ [[0.00, 1.58],
โ โ โ โ [0.00, 1.58]]
โ โ โ static_numeric_value_mask (torch.bool):
โ โ โ โ [[False, True],
โ โ โ โ [False, True]]
โ โ
โ โ Labels:
โ โ โ boolean_value (torch.bool):
โ โ โ โ [False, True]
This is with the default static inclusion mode of "include", which means that the static data is included as
a separate entry in the batch. What about with the other two options, "omit" and "prepend"?
If we use "omit", we can see that the static data is omitted from the output:
>>> pyd_with_task.config.static_inclusion_mode = "omit"
>>> print(next(iter(pyd_with_task.get_dataloader(batch_size=2))))
MEDSTorchBatch:
โ Mode: Subject-Measurement (SM)
โ Static data? โ
โ Labels? โ
โ
โ Shape:
โ โ Batch size: 2
โ โ Sequence length: 5
โ โ
โ โ All dynamic data: (2, 5)
โ โ Labels: torch.Size([2])
โ
โ Data:
โ โ Dynamic:
โ โ โ time_delta_days (torch.float32):
โ โ โ โ [[1.07e+04, 0.00e+00, ..., 4.83e-03, 0.00e+00],
โ โ โ โ [0.00e+00, 4.83e-03, ..., 2.55e-02, 0.00e+00]]
โ โ โ code (torch.int64):
โ โ โ โ [[ 1, 10, ..., 10, 11],
โ โ โ โ [11, 10, ..., 10, 11]]
โ โ โ numeric_value (torch.float32):
โ โ โ โ [[ 0.00e+00, -5.70e-01, ..., -4.38e-01, -1.17e+00],
โ โ โ โ [-1.27e+00, -4.38e-01, ..., 1.32e-03, -1.37e+00]]
โ โ โ numeric_value_mask (torch.bool):
โ โ โ โ [[False, True, ..., True, True],
โ โ โ โ [ True, True, ..., True, True]]
โ โ
โ โ Labels:
โ โ โ boolean_value (torch.bool):
โ โ โ โ [False, True]
What if we use a static inclusion mode of "prepend"? We can see that the static data is prepended to the
dynamic data:
>>> pyd_with_task.config.static_inclusion_mode = "prepend"
>>> print(next(iter(pyd_with_task.get_dataloader(batch_size=2))))
MEDSTorchBatch:
โ Mode: Subject-Measurement (SM)
โ Static data? โ (prepended)
โ Labels? โ
โ
โ Shape:
โ โ Batch size: 2
โ โ Sequence length (static + dynamic): 5
โ โ
โ โ All [static; dynamic] data: (2, 5)
โ โ Labels: torch.Size([2])
โ
โ Data:
โ โ [Static; Dynamic]:
โ โ โ time_delta_days (torch.float32):
โ โ โ โ [[0.00, 0.00, ..., 0.00, 0.00],
โ โ โ โ [0.00, 0.00, ..., 0.03, 0.00]]
โ โ โ code (torch.int64):
โ โ โ โ [[ 7, 9, ..., 10, 11],
โ โ โ โ [ 7, 9, ..., 10, 11]]
โ โ โ numeric_value (torch.float32):
โ โ โ โ [[ 0.00e+00, 1.58e+00, ..., -4.38e-01, -1.17e+00],
โ โ โ โ [ 0.00e+00, 1.58e+00, ..., 1.32e-03, -1.37e+00]]
โ โ โ numeric_value_mask (torch.bool):
โ โ โ โ [[False, True, ..., True, True],
โ โ โ โ [False, True, ..., True, True]]
โ โ โ static_mask (torch.bool):
โ โ โ โ [[ True, True, ..., False, False],
โ โ โ โ [ True, True, ..., False, False]]
โ โ
โ โ Labels:
โ โ โ boolean_value (torch.bool):
โ โ โ โ [False, True]
>>> pyd_with_task.config.static_inclusion_mode = "include" # reset to default
Thus far, our examples have all worked with the default config object, which sets (among other things) the
default output to be at a measurement level, rather than an event level, by virtue of setting
batch_mode to SM. Let's see what happens if we change that:
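Conceptually, "SEM" mode regroups the flat measurement stream into one row per event, padding each event out to the longest event in the window with code 0. A plain-Python sketch (illustrative only; the library does this inside the JointNestedRaggedTensorDict machinery):

```python
def nest_by_event(event_idx, codes, pad=0):
    """Group measurement-level codes into padded event-level rows."""
    events = {}
    for e, c in zip(event_idx, codes):
        events.setdefault(e, []).append(c)
    rows = [events[e] for e in sorted(events)]
    width = max(len(r) for r in rows)
    return [r + [pad] * (width - len(r)) for r in rows]

# The five measurements from pyd[2], spread over three events:
print(nest_by_event([0, 1, 1, 1, 2], [5, 3, 10, 11, 4]))
# [[5, 0, 0], [3, 10, 11], [4, 0, 0]]
```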
>>> pyd.config.batch_mode = "SEM"
>>> print_element(pyd[2])
static_code (list):
[8, 9]
static_numeric_value (list):
[nan, -0.5438239574432373]
dynamic (JointNestedRaggedTensorDict):
time_delta_days
[ nan 1.17661045e+04 9.78703722e-02]
.
---
.
dim1/mask
[[ True False False]
[ True True True]
[ True False False]]
.
code
[[ 5 0 0]
[ 3 10 11]
[ 4 0 0]]
.
numeric_value
[[ nan 0. 0. ]
[ nan -1.4474752 -0.34049404]
[ nan 0. 0. ]]
>>> batches = [batch for batch in pyd.get_dataloader(batch_size=2)]
>>> print(batches[1])
MEDSTorchBatch:
โ Mode: Subject-Event-Measurement (SEM)
โ Static data? โ
โ Labels? โ
โ
โ Shape:
โ โ Batch size: 2
โ โ Sequence length: 3
โ โ Event length: 3
โ โ
โ โ Per-event data: (2, 3)
โ โ Per-measurement data: (2, 3, 3)
โ โ Static data: (2, 2)
โ
โ Data:
โ โ Event-level:
โ โ โ time_delta_days (torch.float32):
โ โ โ โ [[0.00e+00, 1.18e+04, 9.79e-02],
โ โ โ โ [0.00e+00, 1.24e+04, 4.64e-02]]
โ โ โ event_mask (torch.bool):
โ โ โ โ [[True, True, True],
โ โ โ โ [True, True, True]]
โ โ
โ โ Measurement-level:
โ โ โ code (torch.int64):
โ โ โ โ [[[ 5, 0, 0],
โ โ โ โ [ 3, 10, 11],
โ โ โ โ [ 4, 0, 0]],
โ โ โ โ [[ 5, 0, 0],
โ โ โ โ [ 2, 10, 11],
โ โ โ โ [ 4, 0, 0]]]
โ โ โ numeric_value (torch.float32):
โ โ โ โ [[[ 0.00, 0.00, 0.00],
โ โ โ โ [ 0.00, -1.45, -0.34],
โ โ โ โ [ 0.00, 0.00, 0.00]],
โ โ โ โ [[ 0.00, 0.00, 0.00],
โ โ โ โ [ 0.00, 3.00, 0.85],
โ โ โ โ [ 0.00, 0.00, 0.00]]]
โ โ โ numeric_value_mask (torch.bool):
โ โ โ โ [[[False, True, True],
โ โ โ โ [False, True, True],
โ โ โ โ [False, True, True]],
โ โ โ โ [[False, True, True],
โ โ โ โ [False, True, True],
โ โ โ โ [False, True, True]]]
โ โ
โ โ Static:
โ โ โ static_code (torch.int64):
โ โ โ โ [[8, 9],
โ โ โ โ [8, 9]]
โ โ โ static_numeric_value (torch.float32):
โ โ โ โ [[ 0.00, -0.54],
โ โ โ โ [ 0.00, -1.10]]
โ โ โ static_numeric_value_mask (torch.bool):
โ โ โ โ [[False, True],
โ โ โ โ [False, True]]
>>> pyd_with_task.config.batch_mode = "SEM"
>>> print(next(iter(pyd_with_task.get_dataloader(batch_size=2))))
MEDSTorchBatch:
│ Mode: Subject-Event-Measurement (SEM)
│ Static data? ✓
│ Labels? ✓
│
│ Shape:
│ │ Batch size: 2
│ │ Sequence length: 4
│ │ Event length: 3
│ │
│ │ Per-event data: (2, 4)
│ │ Per-measurement data: (2, 4, 3)
│ │ Static data: (2, 2)
│ │ Labels: torch.Size([2])
│
│ Data:
│ │ Event-level:
│ │ │ time_delta_days (torch.float32):
│ │ │ │ [[0.00e+00, 1.07e+04, 4.83e-03, 0.00e+00],
│ │ │ │ [0.00e+00, 1.07e+04, 4.83e-03, 2.55e-02]]
│ │ │ event_mask (torch.bool):
│ │ │ │ [[ True, True, True, False],
│ │ │ │ [ True, True, True, True]]
│ │
│ │ Measurement-level:
│ │ │ code (torch.int64):
│ │ │ │ [[[ 5, 0, 0],
│ │ │ │ [ 1, 10, 11],
│ │ │ │ [10, 11, 0],
│ │ │ │ [ 0, 0, 0]],
│ │ │ │ [[ 5, 0, 0],
│ │ │ │ [ 1, 10, 11],
│ │ │ │ [10, 11, 0],
│ │ │ │ [10, 11, 0]]]
│ │ │ numeric_value (torch.float32):
│ │ │ │ [[[ 0.00e+00, 0.00e+00, 0.00e+00],
│ │ │ │ [ 0.00e+00, -5.70e-01, -1.27e+00],
│ │ │ │ [-4.38e-01, -1.17e+00, 0.00e+00],
│ │ │ │ [ 0.00e+00, 0.00e+00, 0.00e+00]],
│ │ │ │ [[ 0.00e+00, 0.00e+00, 0.00e+00],
│ │ │ │ [ 0.00e+00, -5.70e-01, -1.27e+00],
│ │ │ │ [-4.38e-01, -1.17e+00, 0.00e+00],
│ │ │ │ [ 1.32e-03, -1.37e+00, 0.00e+00]]]
│ │ │ numeric_value_mask (torch.bool):
│ │ │ │ [[[False, True, True],
│ │ │ │ [False, True, True],
│ │ │ │ [ True, True, True],
│ │ │ │ [ True, True, True]],
│ │ │ │ [[False, True, True],
│ │ │ │ [False, True, True],
│ │ │ │ [ True, True, True],
│ │ │ │ [ True, True, True]]]
│ │
│ │ Static:
│ │ │ static_code (torch.int64):
│ │ │ │ [[7, 9],
│ │ │ │ [7, 9]]
│ │ │ static_numeric_value (torch.float32):
│ │ │ │ [[0.00, 1.58],
│ │ │ │ [0.00, 1.58]]
│ │ │ static_numeric_value_mask (torch.bool):
│ │ │ │ [[False, True],
│ │ │ │ [False, True]]
│ │
│ │ Labels:
│ │ │ boolean_value (torch.bool):
│ │ │ │ [False, True]
>>> pyd_with_task.config.static_inclusion_mode = "omit"
>>> print(next(iter(pyd_with_task.get_dataloader(batch_size=2))))
MEDSTorchBatch:
│ Mode: Subject-Event-Measurement (SEM)
│ Static data? ✗
│ Labels? ✓
│
│ Shape:
│ │ Batch size: 2
│ │ Sequence length: 4
│ │ Event length: 3
│ │
│ │ Per-event data: (2, 4)
│ │ Per-measurement data: (2, 4, 3)
│ │ Labels: torch.Size([2])
│
│ Data:
│ │ Event-level:
│ │ │ time_delta_days (torch.float32):
│ │ │ │ [[0.00e+00, 1.07e+04, 4.83e-03, 0.00e+00],
│ │ │ │ [0.00e+00, 1.07e+04, 4.83e-03, 2.55e-02]]
│ │ │ event_mask (torch.bool):
│ │ │ │ [[ True, True, True, False],
│ │ │ │ [ True, True, True, True]]
│ │
│ │ Measurement-level:
│ │ │ code (torch.int64):
│ │ │ │ [[[ 5, 0, 0],
│ │ │ │ [ 1, 10, 11],
│ │ │ │ [10, 11, 0],
│ │ │ │ [ 0, 0, 0]],
│ │ │ │ [[ 5, 0, 0],
│ │ │ │ [ 1, 10, 11],
│ │ │ │ [10, 11, 0],
│ │ │ │ [10, 11, 0]]]
│ │ │ numeric_value (torch.float32):
│ │ │ │ [[[ 0.00e+00, 0.00e+00, 0.00e+00],
│ │ │ │ [ 0.00e+00, -5.70e-01, -1.27e+00],
│ │ │ │ [-4.38e-01, -1.17e+00, 0.00e+00],
│ │ │ │ [ 0.00e+00, 0.00e+00, 0.00e+00]],
│ │ │ │ [[ 0.00e+00, 0.00e+00, 0.00e+00],
│ │ │ │ [ 0.00e+00, -5.70e-01, -1.27e+00],
│ │ │ │ [-4.38e-01, -1.17e+00, 0.00e+00],
│ │ │ │ [ 1.32e-03, -1.37e+00, 0.00e+00]]]
│ │ │ numeric_value_mask (torch.bool):
│ │ │ │ [[[False, True, True],
│ │ │ │ [False, True, True],
│ │ │ │ [ True, True, True],
│ │ │ │ [ True, True, True]],
│ │ │ │ [[False, True, True],
│ │ │ │ [False, True, True],
│ │ │ │ [ True, True, True],
│ │ │ │ [ True, True, True]]]
│ │
│ │ Labels:
│ │ │ boolean_value (torch.bool):
│ │ │ │ [False, True]
>>> pyd_with_task.config.static_inclusion_mode = "prepend"
>>> print(next(iter(pyd_with_task.get_dataloader(batch_size=2))))
MEDSTorchBatch:
│ Mode: Subject-Event-Measurement (SEM)
│ Static data? ✓ (prepended)
│ Labels? ✓
│
│ Shape:
│ │ Batch size: 2
│ │ Sequence length (static + dynamic): 5
│ │ Event length: 3
│ │
│ │ Per-event data: (2, 5)
│ │ Per-measurement data: (2, 5, 3)
│ │ Labels: torch.Size([2])
│
│ Data:
│ │ Event-level:
│ │ │ time_delta_days (torch.float32):
│ │ │ │ [[0.00, 0.00, ..., 0.00, 0.00],
│ │ │ │ [0.00, 0.00, ..., 0.00, 0.03]]
│ │ │ event_mask (torch.bool):
│ │ │ │ [[ True, True, ..., True, False],
│ │ │ │ [ True, True, ..., True, True]]
│ │ │ static_mask (torch.bool):
│ │ │ │ [[ True, False, ..., False, False],
│ │ │ │ [ True, False, ..., False, False]]
│ │
│ │ Measurement-level:
│ │ │ code (torch.int64):
│ │ │ │ [[[ 7, 9, 0],
│ │ │ │ [ 5, 0, 0],
│ │ │ │ ...,
│ │ │ │ [10, 11, 0],
│ │ │ │ [ 0, 0, 0]],
│ │ │ │ [[ 7, 9, 0],
│ │ │ │ [ 5, 0, 0],
│ │ │ │ ...,
│ │ │ │ [10, 11, 0],
│ │ │ │ [10, 11, 0]]]
│ │ │ numeric_value (torch.float32):
│ │ │ │ [[[ 0.00e+00, 1.58e+00, 0.00e+00],
│ │ │ │ [ 0.00e+00, 0.00e+00, 0.00e+00],
│ │ │ │ ...,
│ │ │ │ [-4.38e-01, -1.17e+00, 0.00e+00],
│ │ │ │ [ 0.00e+00, 0.00e+00, 0.00e+00]],
│ │ │ │ [[ 0.00e+00, 1.58e+00, 0.00e+00],
│ │ │ │ [ 0.00e+00, 0.00e+00, 0.00e+00],
│ │ │ │ ...,
│ │ │ │ [-4.38e-01, -1.17e+00, 0.00e+00],
│ │ │ │ [ 1.32e-03, -1.37e+00, 0.00e+00]]]
│ │ │ numeric_value_mask (torch.bool):
│ │ │ │ [[[False, True, True],
│ │ │ │ [False, True, True],
│ │ │ │ ...,
│ │ │ │ [ True, True, True],
│ │ │ │ [ True, True, True]],
│ │ │ │ [[False, True, True],
│ │ │ │ [False, True, True],
│ │ │ │ ...,
│ │ │ │ [ True, True, True],
│ │ │ │ [ True, True, True]]]
│ │
│ │ Labels:
│ │ │ boolean_value (torch.bool):
│ │ │ │ [False, True]
Data Tensorization and Pre-processing Details
Full documentation for the preprocessing pipeline can be found here
The MTD_preprocess command uses Hydra to manage its configuration and execution from
the command line. You can see the available options by running the command with the --help flag:
== MTD_preprocess ==
MTD_preprocess is a command line tool for pre-processing MEDS data for use with meds_torchdata.
== Config ==
This is the config generated for this run:
MEDS_dataset_dir: ???
output_dir: ???
stage_runner_fp: null
do_overwrite: false
do_reshard: false
log_dir: ${output_dir}/.logs
You can override everything using the hydra `key=value` syntax; for example:
MTD_preprocess MEDS_dataset_dir=/path/to/dataset output_dir=/path/to/output do_overwrite=True
The MTD_preprocess command runs the following pre-processing stages:
- `fit_normalization`: Fitting the parameters necessary for normalization from the raw data (e.g., the mean and standard deviation of the `numeric_value` field).
- `fit_vocabulary_indices`: Assigning a unique vocabulary index to each unique `code` in the data so that codes can be transformed to numerical indices for tensorization.
- `normalization`: Normalizing the data to have a mean of 0 and a standard deviation of 1, using the parameters fit in the `fit_normalization` stage.
- `tokenization`: Producing the schema files necessary for the tensorization stage.
- `tensorization`: Producing the nested ragged tensor views of the data.
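Conceptually, the `fit_normalization` and `normalization` stages amount to per-code z-scoring of numeric values. The sketch below is purely illustrative of that idea; it is not the actual MEDS-Transforms implementation, and the `(code, value)` tuple format is a simplified stand-in for real MEDS rows:

```python
# Illustrative sketch of per-code z-score normalization, mirroring what the
# fit_normalization + normalization stages do conceptually. NOT the actual
# MEDS-Transforms implementation.
import math
from collections import defaultdict


def fit_normalization(records):
    """Compute per-code (mean, std) of numeric values from raw records."""
    sums, sqsums, counts = defaultdict(float), defaultdict(float), defaultdict(int)
    for code, value in records:
        if value is None:  # codes without numeric values contribute nothing
            continue
        sums[code] += value
        sqsums[code] += value * value
        counts[code] += 1
    stats = {}
    for code, n in counts.items():
        mean = sums[code] / n
        var = max(sqsums[code] / n - mean * mean, 0.0)
        stats[code] = (mean, math.sqrt(var))
    return stats


def normalize(records, stats):
    """Z-score each numeric value using the fitted per-code statistics."""
    out = []
    for code, value in records:
        if value is None or code not in stats:
            out.append((code, value))
            continue
        mean, std = stats[code]
        out.append((code, (value - mean) / std if std > 0 else 0.0))
    return out
```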
[!NOTE] If you would like additional normalization options to be supported, please comment on the upstream issue in MEDS-Transforms, and file an issue here to capture supporting additional options cleanly going forward.
[!NOTE] You should perform any additional, model-specific pre-processing on your data prior to running the
`MTD_preprocess` command. For example, if you wish to:
- Drop numeric values entirely and convert to quantile-modified codes,
- Drop infrequent codes or aggregate codes into higher-order categories,
- Restrict subjects to a specific time window,
- Drop subjects with infrequent values, or
- Occlude outlier numeric values,

you should perform these steps on the raw MEDS data before running the tensorization command. This ensures that the data is modified as you desire in an efficient, transparent way, and that the tensorization step works with the data in its final format, avoiding issues such as discrepancies in the code vocabulary.
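As one concrete illustration, dropping infrequent codes can be done with a simple filtering pass over the raw data before tensorization. This sketch is illustrative only (a real pipeline would use MEDS-Transforms or a dataframe library, and the tuple format is a simplified stand-in for raw MEDS rows):

```python
# Illustrative sketch: drop measurements whose code is rare in the dataset.
# The (subject_id, time, code, numeric_value) tuples stand in for MEDS rows.
from collections import Counter


def drop_infrequent_codes(records, min_count=10):
    """Keep only measurements whose code occurs at least `min_count` times."""
    counts = Counter(code for _, _, code, _ in records)
    return [r for r in records if counts[r[2]] >= min_count]
```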
Advanced features
You can also use this package natively with Hydra in modeling applications by adding the
meds_torchdata.MEDSTorchDataConfig to the Hydra config store. This allows you to use it as though it
were a fully defined .yaml configuration file in your application configuration. To do so, simply call
MEDSTorchDataConfig.add_to_config_store() in your application, specifying the group name under which you
plan to use the config.
E.g., if you have a config file like this:
dataset:
_target_: meds_torchdata.MEDSPytorchDataset
config: MEDSTorchDataConfig
Then in your main application, prior to @hydra.main, you can run:
from meds_torchdata.config import MEDSTorchDataConfig
MEDSTorchDataConfig.add_to_config_store("dataset/config")
This will add the MEDSTorchDataConfig to the Hydra config store in the nested dataset/config group, which
will allow you to override its parameters from the command line and instantiate it into object form natively.
Testing Models that Use this Package
If you use this package to build your model, we also expose some pytest fixtures you can use to test it.
They are similar to the four fixtures used above in the Examples and Detailed Usage section.
Namely, they are:
- `tensorized_MEDS_dataset`: A fixture that points to a Path containing the tensorized and schema files for the `simple_static_MEDS` dataset.
- `tensorized_MEDS_dataset_with_task`: A fixture that points to a tuple containing:
  - A Path containing the tensorized and schema files for the `simple_static_MEDS_dataset_with_task` dataset
  - A Path pointing to the root task directory for the dataset
  - The specific task name for the dataset. Task label files will be stored in a subdirectory of the root task directory with this name.
- `sample_pytorch_dataset`: Yields a `MEDSPytorchDataset` object built using the `tensorized_MEDS_dataset` data, without a downstream task.
- `sample_pytorch_dataset_with_task`: Yields a `MEDSPytorchDataset` object built using the `tensorized_MEDS_dataset_with_task` data, with the associated downstream task.
You can rely on these fixtures to test your model in the normal way, directly having your model train using input batches derived from the sample datasets.
Performance
See https://mmcdermott.github.io/meds-torch-data/dev/bench/ for performance benchmarks for all commits in this repository, and see here for the benchmarking script. Note that these benchmarks are likely to change over time (e.g., we may benchmark on larger or more complex synthetic data), so they should be judged relative to the content of the associated commits, not in absolute terms.