An efficient, flexible PyTorch dataset class for MEDS data.
MEDS TorchData: A PyTorch Dataset Class for MEDS Datasets
Quick Start
Step 1: Install
pip install meds-torch-data
Step 2: Tensorize the data:
[!WARNING] If your dataset is not sharded by split, you need to run a reshard-to-split stage first! You can enable this by adding the `do_reshard=True` argument to the command below.
MEDS_tensorize input_dir=... output_dir=...
Step 3: Use the dataset:
In your code, simply:
from meds_torchdata import MEDSPytorchDataset
pyd = MEDSPytorchDataset(...)
To see how this works, let's look at some examples. These examples will be powered by some synthetic data defined as "fixtures" in this package's pytest stack; namely, we'll use the following fixtures:
- `simple_static_MEDS`: points to a Path containing a simple MEDS dataset.
- `simple_static_MEDS_dataset_with_task`: points to a Path containing a simple MEDS dataset with a boolean-valued task defined. The core data is the same as in `simple_static_MEDS`, but this dataset additionally has a task defined.
- `tensorized_MEDS_dataset`: points to a Path containing the tensorized and schema files for the `simple_static_MEDS` dataset.
- `tensorized_MEDS_dataset_with_task`: points to a Path containing the tensorized and schema files for the `simple_static_MEDS_dataset_with_task` dataset.
You can find these in either the conftest.py file for this repository or the
meds_testing_helpers package, which
this package leverages for testing.
To start, let's take a look at this synthetic data. It is sharded by split, and we'll look at the train split first, which has two shards (we convert to polars just for prettier printing). It has four subjects across the two shards.
>>> import polars as pl
>>> from meds_testing_helpers.dataset import MEDSDataset
>>> D = MEDSDataset(root_dir=simple_static_MEDS)
>>> train_0 = pl.from_arrow(D.data_shards["train/0"])
>>> train_0
shape: (30, 4)
┌────────────┬─────────────────────┬────────────────────┬───────────────┐
│ subject_id ┆ time                ┆ code               ┆ numeric_value │
│ ---        ┆ ---                 ┆ ---                ┆ ---           │
│ i64        ┆ datetime[μs]        ┆ str                ┆ f32           │
╞════════════╪═════════════════════╪════════════════════╪═══════════════╡
│ 239684     ┆ null                ┆ EYE_COLOR//BROWN   ┆ null          │
│ 239684     ┆ null                ┆ HEIGHT             ┆ 175.271118    │
│ 239684     ┆ 1980-12-28 00:00:00 ┆ DOB                ┆ null          │
│ 239684     ┆ 2010-05-11 17:41:51 ┆ ADMISSION//CARDIAC ┆ null          │
│ 239684     ┆ 2010-05-11 17:41:51 ┆ HR                 ┆ 102.599998    │
│ …          ┆ …                   ┆ …                  ┆ …             │
│ 1195293    ┆ 2010-06-20 20:24:44 ┆ HR                 ┆ 107.699997    │
│ 1195293    ┆ 2010-06-20 20:24:44 ┆ TEMP               ┆ 100.0         │
│ 1195293    ┆ 2010-06-20 20:41:33 ┆ HR                 ┆ 107.5         │
│ 1195293    ┆ 2010-06-20 20:41:33 ┆ TEMP               ┆ 100.400002    │
│ 1195293    ┆ 2010-06-20 20:50:04 ┆ DISCHARGE          ┆ null          │
└────────────┴─────────────────────┴────────────────────┴───────────────┘
>>> train_1 = pl.from_arrow(D.data_shards["train/1"])
>>> train_1
shape: (14, 4)
┌────────────┬─────────────────────┬───────────────────────┬───────────────┐
│ subject_id ┆ time                ┆ code                  ┆ numeric_value │
│ ---        ┆ ---                 ┆ ---                   ┆ ---           │
│ i64        ┆ datetime[μs]        ┆ str                   ┆ f32           │
╞════════════╪═════════════════════╪═══════════════════════╪═══════════════╡
│ 68729      ┆ null                ┆ EYE_COLOR//HAZEL      ┆ null          │
│ 68729      ┆ null                ┆ HEIGHT                ┆ 160.395309    │
│ 68729      ┆ 1978-03-09 00:00:00 ┆ DOB                   ┆ null          │
│ 68729      ┆ 2010-05-26 02:30:56 ┆ ADMISSION//PULMONARY  ┆ null          │
│ 68729      ┆ 2010-05-26 02:30:56 ┆ HR                    ┆ 86.0          │
│ …          ┆ …                   ┆ …                     ┆ …             │
│ 814703     ┆ 1976-03-28 00:00:00 ┆ DOB                   ┆ null          │
│ 814703     ┆ 2010-02-05 05:55:39 ┆ ADMISSION//ORTHOPEDIC ┆ null          │
│ 814703     ┆ 2010-02-05 05:55:39 ┆ HR                    ┆ 170.199997    │
│ 814703     ┆ 2010-02-05 05:55:39 ┆ TEMP                  ┆ 100.099998    │
│ 814703     ┆ 2010-02-05 07:02:30 ┆ DISCHARGE             ┆ null          │
└────────────┴─────────────────────┴───────────────────────┴───────────────┘
>>> sorted(set(train_0["subject_id"].unique()) | set(train_1["subject_id"].unique()))
[68729, 239684, 814703, 1195293]
Given this data, when we build a PyTorch dataset from it for training with no task specified, the length will be four, as each element corresponds to one of the four subjects in the train split. The index variable contains, for each element, the subject ID and the end of the allowed reading region. We can also see this in dataframe form via the schema_df:
>>> from meds_torchdata.pytorch_dataset import MEDSTorchDataConfig, MEDSPytorchDataset
>>> cfg = MEDSTorchDataConfig(tensorized_cohort_dir=tensorized_MEDS_dataset, max_seq_len=5)
>>> pyd = MEDSPytorchDataset(cfg, split="train")
>>> len(pyd)
4
>>> pyd.index
[(68729, 3), (814703, 3), (239684, 6), (1195293, 8)]
>>> pyd.schema_df
shape: (4, 2)
┌────────────┬─────────────────┐
│ subject_id ┆ end_event_index │
│ ---        ┆ ---             │
│ i64        ┆ u32             │
╞════════════╪═════════════════╡
│ 68729      ┆ 3               │
│ 814703     ┆ 3               │
│ 239684     ┆ 6               │
│ 1195293    ┆ 8               │
└────────────┴─────────────────┘
Note the index is in terms of event indices, not measurement indices -- meaning it is the index of the unique timestamp corresponding to the start and end of each subject's data, not of the individual measurement. We can validate that against the raw data. To do so, we'll define two simple helpers: get_event_indices, which groups by the subject_id and time columns and computes a per-subject event index (along with the number of measurements per event), and get_event_bounds, which shows the minimum and maximum such index per subject.
>>> def get_event_indices(df: pl.DataFrame) -> pl.DataFrame:
... return (
... df
... .group_by("subject_id", "time", maintain_order=True).agg(pl.len().alias("n_measurements"))
... .with_row_index()
... .select(
... "subject_id", "time",
... (pl.col("index") - pl.col("index").min().over("subject_id")).alias("event_idx"),
... "n_measurements",
... )
... )
>>> def get_event_bounds(df: pl.DataFrame) -> pl.DataFrame:
... return (
... get_event_indices(df)
... .with_columns(
... pl.col("event_idx").max().over("subject_id").alias("max_event_idx")
... )
... .filter((pl.col("event_idx") == 0) | (pl.col("event_idx") == pl.col("max_event_idx")))
... .select("subject_id", "event_idx", "time")
... )
>>> get_event_bounds(train_1)
shape: (4, 3)
┌────────────┬───────────┬─────────────────────┐
│ subject_id ┆ event_idx ┆ time                │
│ ---        ┆ ---       ┆ ---                 │
│ i64        ┆ u32       ┆ datetime[μs]        │
╞════════════╪═══════════╪═════════════════════╡
│ 68729      ┆ 0         ┆ null                │
│ 68729      ┆ 3         ┆ 2010-05-26 04:51:52 │
│ 814703     ┆ 0         ┆ null                │
│ 814703     ┆ 3         ┆ 2010-02-05 07:02:30 │
└────────────┴───────────┴─────────────────────┘
>>> get_event_bounds(train_0)
shape: (4, 3)
┌────────────┬───────────┬─────────────────────┐
│ subject_id ┆ event_idx ┆ time                │
│ ---        ┆ ---       ┆ ---                 │
│ i64        ┆ u32       ┆ datetime[μs]        │
╞════════════╪═══════════╪═════════════════════╡
│ 239684     ┆ 0         ┆ null                │
│ 239684     ┆ 6         ┆ 2010-05-11 19:27:19 │
│ 1195293    ┆ 0         ┆ null                │
│ 1195293    ┆ 8         ┆ 2010-06-20 20:50:04 │
└────────────┴───────────┴─────────────────────┘
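The grouping logic behind these event indices can also be sketched in pure Python. This is a toy stand-in for the polars helpers above, not the package's implementation: times are simplified to integers, `None` stands in for the null-time static rows, and the row data is invented.

```python
from itertools import groupby

# Toy rows mimicking the (subject_id, time) structure of the MEDS data:
rows = [
    (68729, None), (68729, None),  # static measurements -> one null-time event
    (68729, 1), (68729, 1),        # one admission event with two measurements
    (68729, 2),
    (68729, 3),
]

def end_event_index(rows):
    """Map each subject_id to the index of its last event, where an 'event'
    is a unique (subject_id, time) pair and indices start at 0 per subject."""
    out = {}
    for subject, subject_rows in groupby(rows, key=lambda r: r[0]):
        n_events = len({t for _, t in subject_rows})
        out[subject] = n_events - 1
    return out

print(end_event_index(rows))  # {68729: 3}
```

The toy subject has four unique times (one null plus three timestamps), so its end event index is 3, matching the (68729, 3) entry in pyd.index above.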
While the raw data stores codes as strings, they are naturally converted to integers when embedded in the PyTorch dataset. This happens during the aforementioned tensorization step. We can see how the codes are mapped to integers by looking at the code metadata output by that step:
>>> code_metadata = pl.read_parquet(tensorized_MEDS_dataset.joinpath("metadata/codes.parquet"))
>>> code_metadata.select("code", "code/vocab_index")
shape: (11, 2)
┌───────────────────────┬──────────────────┐
│ code                  ┆ code/vocab_index │
│ ---                   ┆ ---              │
│ str                   ┆ u8               │
╞═══════════════════════╪══════════════════╡
│ ADMISSION//CARDIAC    ┆ 1                │
│ ADMISSION//ORTHOPEDIC ┆ 2                │
│ ADMISSION//PULMONARY  ┆ 3                │
│ DISCHARGE             ┆ 4                │
│ DOB                   ┆ 5                │
│ …                     ┆ …                │
│ EYE_COLOR//BROWN      ┆ 7                │
│ EYE_COLOR//HAZEL      ┆ 8                │
│ HEIGHT                ┆ 9                │
│ HR                    ┆ 10               │
│ TEMP                  ┆ 11               │
└───────────────────────┴──────────────────┘
We can see these vocab indices in use if we look at some elements of the PyTorch dataset. Note that some elements of the returned dictionaries are JointNestedRaggedTensorDict objects, so we'll define a helper that uses a utility from the associated library to pretty-print the outputs. Note that we'll also reduce the precision of the numeric values to make the output more readable.
>>> from nested_ragged_tensors.ragged_numpy import pprint_dense
>>> def print_element(el: dict):
... for k, v in el.items():
... print(f"{k} ({type(v).__name__}):")
... if k == "dynamic":
... pprint_dense(v.to_dense())
... else:
... print(v)
>>> print_element(pyd[0])
static_code (list):
[8, 9]
static_numeric_value (list):
[nan, -0.5438239574432373]
dynamic (JointNestedRaggedTensorDict):
code
[ 5 3 10 11 4]
.
numeric_value
[ nan nan -1.4474752 -0.34049404 nan]
.
time_delta_days
[ nan 1.17661045e+04 0.00000000e+00 0.00000000e+00
9.78703722e-02]
The contents of pyd[0] are stable because index element 0, (68729, 3), indicates the first subject's data ends at event index 3; their full flattened sequence of 5 measurements fits within our max_seq_len of 5.
>>> print_element(pyd[0])
static_code (list):
[8, 9]
static_numeric_value (list):
[nan, -0.5438239574432373]
dynamic (JointNestedRaggedTensorDict):
code
[ 5 3 10 11 4]
.
numeric_value
[ nan nan -1.4474752 -0.34049404 nan]
.
time_delta_days
[ nan 1.17661045e+04 0.00000000e+00 0.00000000e+00
9.78703722e-02]
If we sampled a different subject, one with more than 5 events, the output we'd get would depend on the config.seq_sampling_strategy option and could be non-deterministic. By default, this is set to random, so we'd get a random subset of length 5 each time. Here, so that this code is deterministic, we'll use the internal, seeded version of the getitem call, which simply allows a seed to be added to the normal getitem call.
>>> print_element(pyd._seeded_getitem(3, seed=0))
static_code (list):
[6, 9]
static_numeric_value (list):
[nan, 0.06802856922149658]
dynamic (JointNestedRaggedTensorDict):
code
[10 11 10 11 10]
.
numeric_value
[-0.04626633 0.69391906 -0.30007038 0.79735875 -0.31064537]
.
time_delta_days
[0.01888889 0. 0.0084838 0. 0.01167824]
>>> print_element(pyd._seeded_getitem(3, seed=1))
static_code (list):
[6, 9]
static_numeric_value (list):
[nan, 0.06802856922149658]
dynamic (JointNestedRaggedTensorDict):
code
[10 11 10 11 10]
.
numeric_value
[ 0.03833488 0.79735875 0.33972722 0.7456389 -0.04626633]
.
time_delta_days
[0.00115741 0. 0.01373843 0. 0.01888889]
We can also examine not just individual elements, but full batches, which we can access with the appropriate collate function via the built-in get_dataloader method:
>>> print_element(next(iter(pyd.get_dataloader(batch_size=2))))
time_delta_days (Tensor):
tensor([[0.0000e+00, 1.1766e+04, 0.0000e+00, 0.0000e+00, 9.7870e-02],
[0.0000e+00, 1.2367e+04, 0.0000e+00, 0.0000e+00, 4.6424e-02]])
code (Tensor):
tensor([[ 5, 3, 10, 11, 4],
[ 5, 2, 10, 11, 4]])
mask (Tensor):
tensor([[True, True, True, True, True],
[True, True, True, True, True]])
numeric_value (Tensor):
tensor([[ 0.0000, 0.0000, -1.4475, -0.3405, 0.0000],
[ 0.0000, 0.0000, 3.0047, 0.8491, 0.0000]])
numeric_value_mask (Tensor):
tensor([[False, False, True, True, False],
[False, False, True, True, False]])
static_code (Tensor):
tensor([[8, 9],
[8, 9]])
static_numeric_value (Tensor):
tensor([[ 0.0000, -0.5438],
[ 0.0000, -1.1012]])
static_numeric_value_mask (Tensor):
tensor([[False, True],
[False, True]])
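The batch above suggests how missing numeric values are represented after collation: NaNs from the single-element outputs appear as 0.0, with numeric_value_mask recording which entries carried a real measurement. A minimal sketch of that pattern, mirroring the first row of the batch (an illustration of the apparent convention, not the package's collate code):

```python
import math

# NaN numeric values (codes with no measurement) are zero-filled, and a
# boolean mask records which entries were real values.
raw = [float("nan"), float("nan"), -1.4475, -0.3405, float("nan")]
numeric_value = [0.0 if math.isnan(v) else v for v in raw]
numeric_value_mask = [not math.isnan(v) for v in raw]
print(numeric_value)       # [0.0, 0.0, -1.4475, -0.3405, 0.0]
print(numeric_value_mask)  # [False, False, True, True, False]
```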
Thus far, our examples have all worked with the default config object, which sets (among other things) the
default output to be at a measurement level, rather than an event level, by virtue of setting
do_flatten_tensors to True. Let's see what happens if we change that:
>>> pyd.config.do_flatten_tensors = False
>>> print_element(pyd[0])
static_code (list):
[8, 9]
static_numeric_value (list):
[nan, -0.5438239574432373]
dynamic (JointNestedRaggedTensorDict):
time_delta_days
[ nan 1.17661045e+04 9.78703722e-02]
.
---
.
dim1/mask
[[ True False False]
[ True True True]
[ True False False]]
.
code
[[ 5 0 0]
[ 3 10 11]
[ 4 0 0]]
.
numeric_value
[[ nan 0. 0. ]
[ nan -1.4474752 -0.34049404]
[ nan 0. 0. ]]
>>> print_element(next(iter(pyd.get_dataloader(batch_size=2))))
time_delta_days (Tensor):
tensor([[0.0000e+00, 1.1766e+04, 9.7870e-02],
[0.0000e+00, 1.2367e+04, 4.6424e-02]])
code (Tensor):
tensor([[[ 5, 0, 0],
[ 3, 10, 11],
[ 4, 0, 0]],
<BLANKLINE>
[[ 5, 0, 0],
[ 2, 10, 11],
[ 4, 0, 0]]])
mask (Tensor):
tensor([[True, True, True],
[True, True, True]])
numeric_value (Tensor):
tensor([[[ 0.0000, 0.0000, 0.0000],
[ 0.0000, -1.4475, -0.3405],
[ 0.0000, 0.0000, 0.0000]],
<BLANKLINE>
[[ 0.0000, 0.0000, 0.0000],
[ 0.0000, 3.0047, 0.8491],
[ 0.0000, 0.0000, 0.0000]]])
numeric_value_mask (Tensor):
tensor([[[False, True, True],
[False, True, True],
[False, True, True]],
<BLANKLINE>
[[False, True, True],
[False, True, True],
[False, True, True]]])
static_code (Tensor):
tensor([[8, 9],
[8, 9]])
static_numeric_value (Tensor):
tensor([[ 0.0000, -0.5438],
[ 0.0000, -1.1012]])
static_numeric_value_mask (Tensor):
tensor([[False, True],
[False, True]])
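The difference between the two modes can be sketched with the codes from pyd[0]: nested mode keeps one ragged list of measurements per event (padded to the longest event, with a mask), while flattened mode concatenates them into a single measurement-level sequence. This is an illustration of the output shapes, not the library's code:

```python
# Per-event codes, as in the nested (do_flatten_tensors=False) output above.
events = [[5], [3, 10, 11], [4]]

# Flattened mode: one measurement-level sequence.
flat = [code for event in events for code in event]

# Nested mode: pad each event to the longest event and track a mask.
max_len = max(len(event) for event in events)
padded = [event + [0] * (max_len - len(event)) for event in events]
mask = [[True] * len(event) + [False] * (max_len - len(event)) for event in events]

print(flat)    # [5, 3, 10, 11, 4]
print(padded)  # [[5, 0, 0], [3, 10, 11], [4, 0, 0]]
```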
Documentation
Design Principles
A good PyTorch dataset class should:
- Be easy to use
- Have a minimal, constant resource footprint (memory, CPU, start-up time) during model training and inference, regardless of the overall dataset size.
- Perform as much work as possible in static, re-usable dataset pre-processing, rather than upon construction or in the getitem method.
- Induce effectively negligible computational overhead in the getitem method relative to model training.
- Be easily configurable, with a simple, consistent API, and cover the most common use-cases.
- Encourage efficient use of GPU resources in the resulting batches.
- Be comprehensively documented, tested, and benchmarked for performance implications so users can use it reliably and effectively.
To achieve this, MEDS TorchData leverages the following design principles:
- Lazy Loading: Data is loaded only when needed, and only the data needed for the current batch is loaded.
- Efficient Loading: Data is loaded efficiently leveraging the HuggingFace Safetensors library for raw IO through the nested, ragged interface encoded in the Nested Ragged Tensors library.
- Configurable, Transparent Pre-processing: Mandatory data pre-processing prior to effective use in this library is managed through a simple MEDS-Transforms pipeline which can be run on any MEDS dataset, after any model-specific pre-processing, via a transparent configuration file.
- Continuous Integration: The library is continuously tested and benchmarked for performance implications, and the results are available to users.
API and Usage
Data Tensorization and Pre-processing
The MEDS_tensorize command-line utility is used to convert the input MEDS data into a format that can be
loaded into the PyTorch dataset class contained in this package. This command performs a very simple series of
steps:
- Normalize the data into an appropriate numerical format, including:
  - Assigning each unique `code` in the data a unique integer index and converting the codes to those integer indices.
  - Normalizing the `numeric_value` field to have a mean of 0 and a standard deviation of 1. If you would like additional normalization options supported, such as min-max normalization, please file a GitHub issue.
- Produce a set of static "schema" files that contain the unique time-points of each subject's events as well as their static measurements.
- Produce a set of `JointNestedRaggedTensorDict` object files that contain each subject's dynamic measurements in the form of nested, ragged tensors that can be efficiently loaded via the associated package.
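The numeric_value normalization sub-step amounts to a z-score. A minimal sketch, assuming a plain population z-score over the observed values (the real pipeline computes these statistics via MEDS-Transforms; the values below are invented):

```python
# Z-score normalization: subtract the mean, divide by the standard deviation.
values = [86.0, 102.6, 107.5, 170.2]
mean = sum(values) / len(values)
std = (sum((v - mean) ** 2 for v in values) / len(values)) ** 0.5
normalized = [(v - mean) / std for v in values]

# The normalized values have mean ~0 and standard deviation ~1.
print(round(sum(normalized), 10))
```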
These are the only three steps this pipeline performs. Note, however, that this does not mean you can't or shouldn't perform additional, model-specific pre-processing on the data prior to running the tensorization command for your specific use-case. Indeed, if you wish to perform additional pre-processing, such as:
- Dropping numeric values entirely and converting to quantile-modified codes
- Dropping infrequent codes or aggregating codes into higher-order categories
- Restricting subjects to a specific time-window
- Dropping subjects with infrequent values
- Occluding outlier numeric values
- etc.

you should perform these steps on the raw MEDS data prior to running the tensorization command. This ensures that the data is modified as you desire in an efficient, transparent way, and that the tensorization step works with data in its final format, avoiding issues such as discrepancies in the code vocabulary.
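As one concrete example of such a pre-processing pass, dropping infrequent codes before tensorization fixes the vocabulary up front. This is a hypothetical sketch with made-up rows and threshold, not a MEDS-Transforms stage:

```python
from collections import Counter

# Drop codes observed fewer than min_count times before tensorization.
min_count = 2
rows = [("HR", 102.6), ("HR", 105.1), ("TEMP", 96.0), ("EYE_COLOR//GRAY", None)]
counts = Counter(code for code, _ in rows)
kept = [(code, value) for code, value in rows if counts[code] >= min_count]
print(kept)  # [('HR', 102.6), ('HR', 105.1)]
```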
Dataset Class
Once the data has been tensorized, you can use the MEDSPytorchDataset class to load the data into a PyTorch dataset suitable to begin modeling! This dataset class takes a MEDSTorchDataConfig configuration object as input.
Performance
See https://mmcdermott.github.io/meds-torch-data/dev/bench/ for performance benchmarks for all commits in this repository, and see here for the benchmarking script. Note that these benchmarks are likely to change over time, so they should be judged relative to the content of the associated commits, not in absolute terms (e.g., we are likely to benchmark on more, or more complex, synthetic data over time).