# Nested Ragged Tensors

Utilities for efficiently working with, saving, and loading collections of connected nested ragged tensors in PyTorch.

Install with `pip install nested-ragged-tensors`.

This package contains utilities for efficiently working with joint collections of nested ragged tensors.
Some terminology:

- A "tensor" is a multi-dimensional array of values. Typically, tensors are stored as dense arrays, such that, when viewed as lists of lists of ... lists, all lists at a given level of nesting have the same length. E.g., a 2D tensor might be stored as `[[1, 2], [3, 4], [5, 6]]`.
- A "ragged tensor" is a tensor where the lists at a given level of nesting do not all have the same length. E.g., a ragged 2D tensor might be stored as `[[1, 2], [3], [4, 5, 6]]`.
- A "nested ragged tensor" is a set of ragged tensors whose sub-list lengths are hierarchically connected. E.g., we might have a collection of patients, each of whom has a collection of clinic visits, each of which has a collection of laboratory test codes. In this case, while the number of visits and codes per visit are ragged, the codes are still "per-visit", and thus their second level of "raggedness" is identical to the first level of "raggedness" for the visits.
- A collection of "joint nested ragged tensors" is a set of nested ragged tensors that are further connected beyond the (shared) hierarchical connections, such that multiple tensors at a given level of the hierarchy further share the same lengths of their sublists. E.g., we might have a collection of patients, each of whom has a collection of clinic visits, each of which has a collection of laboratory test codes and numeric values. In this case, we have a collection of ragged tensors that are connected in a tree-like structure, where the number of visits and codes per visit can vary across patients and visits, respectively, but the number of codes is the same as the number of values for any given visit.
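To make the "ragged" idea concrete, here is a minimal pure-Python sketch (not this package's internal representation, which is its own implementation detail) of how a ragged 2D tensor can be stored flat as one values array plus per-row lengths and offsets, with no padding:

```python
# A ragged 2D tensor as a flat values array plus per-row lengths.
ragged = [[1, 2], [3], [4, 5, 6]]

values = [v for row in ragged for v in row]   # [1, 2, 3, 4, 5, 6]
lengths = [len(row) for row in ragged]        # [2, 1, 3]

# Offsets (cumulative lengths) let us recover any row without padding.
offsets = [0]
for n in lengths:
    offsets.append(offsets[-1] + n)           # [0, 2, 3, 6]

def row(i):
    """Recover row i from the flat representation."""
    return values[offsets[i]:offsets[i + 1]]

print(row(2))  # [4, 5, 6]
```

Storing values flat like this is what makes ragged representations cheap: memory scales with the number of actual values, not with the largest row.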
This package helps you work with joint nested ragged tensors in a way that is similar to how you might work with dense tensors but dramatically more efficient in terms of memory, time, and disk usage than if you actually densified the ragged tensors. In addition to the generic speed-ups obtained by using ragged tensors, further (minor) speed-ups are obtained by the fact that the ragged offsets for these tensors can be shared across levels of the hierarchy given the joint, nested nature of the tensors.
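As a rough illustration of that sharing (a pure-Python sketch with hypothetical names, not the library's internals): because two jointly ragged tensors such as per-visit codes and per-visit values have identical sublist lengths, a single offsets array can index both flat value arrays at once:

```python
# One patient's three visits, with codes and values jointly ragged:
# codes  [[1, 2, 3], [3, 4]] ... values [[1.0, 0.2, 0.0], [3.1, 0.0]] ...
ids = [1, 2, 3, 3, 4, 1, 2]
vals = [1.0, 0.2, 0.0, 3.1, 0.0, 1.0, 2.2]

# Because the raggedness is shared, one offsets array serves both tensors.
shared_offsets = [0, 3, 5, 7]

def visit(j):
    lo, hi = shared_offsets[j], shared_offsets[j + 1]
    return ids[lo:hi], vals[lo:hi]

print(visit(1))  # ([3, 4], [3.1, 0.0])
```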
## Usage

### The `JointNestedRaggedTensorDict` Class

The main class in this package is `JointNestedRaggedTensorDict`, which is a dictionary-like object that stores a collection of joint nested ragged tensors. You can create this from either (a) a dictionary of raw lists of ... lists of values, (b) a filepath storing a `JointNestedRaggedTensorDict` on disk in the HuggingFace safetensors format, or (c) a pre-processed set of tensors (this is likely not useful, but is used internally).
```python
>>> from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict, pprint_dense
>>> J = JointNestedRaggedTensorDict({
...     "T":   [[1, 2, 3        ], [4, 5          ], [6,  7]],
...     "id":  [[[1, 2, 3], [3, 4],   [1,   2    ]], [[3],  [3,   2, 2]], [[],   [8,  9]]],
...     "val": [[[1, 0.2, 0], [3.1, 0], [1,   2.2]], [[3],  [3.3, 2, 0]], [[],   [1., 0]]],
... })
```
Here, `J` is a `JointNestedRaggedTensorDict` with three tensors: `T`, `id`, and `val`. To see how these data would look were you to densify them, you can call `to_dense()` on the object. Note that, just to make things easier to read, we'll use the `pprint_dense` function to print the densified data. This is defined in the `nested_ragged_tensors.ragged_numpy` module, and is not part of the `JointNestedRaggedTensorDict` object. It does not modify the data; it merely prints things nicely (and replaces empty lines with a `'.'` character so we can use `doctest` to test the Python code in this file and ensure it stays correct without needing a bunch of `<BLANKLINE>`s).
```python
>>> pprint_dense(J.to_dense())
dim1/mask
[[ True  True  True]
 [ True  True False]
 [ True  True False]]
.
T
[[1 2 3]
 [4 5 0]
 [6 7 0]]
.
---
.
dim2/mask
[[[ True  True  True]
  [ True  True False]
  [ True  True False]]
.
 [[ True False False]
  [ True  True  True]
  [False False False]]
.
 [[False False False]
  [ True  True False]
  [False False False]]]
.
id
[[[1 2 3]
  [3 4 0]
  [1 2 0]]
.
 [[3 0 0]
  [3 2 2]
  [0 0 0]]
.
 [[0 0 0]
  [8 9 0]
  [0 0 0]]]
.
val
[[[1.  0.2 0. ]
  [3.1 0.  0. ]
  [1.  2.2 0. ]]
.
 [[3.  0.  0. ]
  [3.3 2.  0. ]
  [0.  0.  0. ]]
.
 [[0.  0.  0. ]
  [1.  0.  0. ]
  [0.  0.  0. ]]]
```
Note a few things:

- The densified result contains a `mask` tensor that indicates which values within a given level are true data elements and which are padding (`True` indicates the data exists, `False` indicates padding).
- Padding is added to the right by default, but can be added to the left by setting `padding_side="left"` in the `to_dense` call. The padding value is `0`.
- Each level is only densified to that level of nesting.
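A minimal sketch of the padding behavior described above (plain Python with a hypothetical `densify` helper, not the library's own code): each row is padded with `0` to the width of the longest row, and a boolean mask records which entries are real data.

```python
# Densify one ragged level: pad each row to the max width and build a mask.
rows = [[1, 2, 3], [4, 5], [6, 7]]
width = max(len(r) for r in rows)

def densify(rows, width, padding_side="right"):
    dense, mask = [], []
    for r in rows:
        pad = [0] * (width - len(r))
        ok = [True] * len(r)
        if padding_side == "right":
            dense.append(r + pad)
            mask.append(ok + [False] * len(pad))
        else:  # pad on the left instead
            dense.append(pad + r)
            mask.append([False] * len(pad) + ok)
    return dense, mask

dense, mask = densify(rows, width)
print(dense)    # [[1, 2, 3], [4, 5, 0], [6, 7, 0]]
print(mask[1])  # [True, True, False]
```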
## Slicing and Operating
We can also perform certain operations on the `JointNestedRaggedTensorDict` object that act like operations on the densified view, without actually densifying the data. For example:
```python
>>> len(J)  # this is the first dimension of the shape of the densified data
3
>>> J1 = J[1]
>>> pprint_dense(J1.to_dense())
T
[4 5]
.
---
.
dim1/mask
[[ True False False]
 [ True  True  True]]
.
id
[[3 0 0]
 [3 2 2]]
.
val
[[3.  0.  0. ]
 [3.3 2.  0. ]]
```
Note that the `__getitem__` method on the `JointNestedRaggedTensorDict` object returns a new tensor that is the sliced version of the object on which it is called --- this is not the same as accessing a key of the dictionary or doing a raw slice on the densified object. In particular, note that the shape of the output can differ, if we no longer need to include an empty dimension of padding, as in the example above for `val` and `id`. Only valid slice keys are accepted, not strings:
```python
>>> J["T"]
Traceback (most recent call last):
    ...
TypeError: <class 'str'> not supported for JointNestedRaggedTensorDict slicing
```
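To see why slicing can stay cheap, here is a pure-Python sketch (hypothetical helper names, not the library's internals) of offset-based slicing over the `id` tensor from the example above: selecting a patient just narrows the offset ranges, so no values are copied or densified.

```python
# Flat "id" values for the three patients in the example above.
values = [1, 2, 3, 3, 4, 1, 2, 3, 3, 2, 2, 8, 9]
inner_offsets = [0, 3, 5, 7, 8, 11, 11, 13]  # per-visit code ranges
outer_offsets = [0, 3, 5, 7]                 # per-patient visit ranges

def select_patient(i):
    """Slice out patient i's codes without touching other patients' data."""
    lo, hi = outer_offsets[i], outer_offsets[i + 1]
    return [values[inner_offsets[j]:inner_offsets[j + 1]] for j in range(lo, hi)]

print(select_patient(1))  # [[3], [3, 2, 2]]
```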
In addition to slicing, we can perform a bevy of other operations on the data, such as concatenation, stacking, squeezing, unsqueezing, and flattening over selected dimensions.
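As one illustration of why such operations are cheap in a ragged representation (a conceptual pure-Python sketch, not the library's API): concatenating two ragged tensors in the flat values-plus-lengths form needs no padding or dense copies --- just append the values and the lengths.

```python
# Two ragged tensors in flat form: [[1, 2, 3], [4, 5]] and [[6, 7]].
a_values, a_lengths = [1, 2, 3, 4, 5], [3, 2]
b_values, b_lengths = [6, 7], [2]

# Concatenation is simply appending values and lengths.
cat_values = a_values + b_values
cat_lengths = a_lengths + b_lengths

# Recover the nested form to check the result.
out, pos = [], 0
for n in cat_lengths:
    out.append(cat_values[pos:pos + n])
    pos += n
print(out)  # [[1, 2, 3], [4, 5], [6, 7]]
```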
## Tensors on Disk
One of the most powerful aspects of this class is the ability to naturally work with these data from disk in a
fast, scalable fashion (powered by HuggingFace Safetensors).
Tensors can be saved to disk via the `save` method, then used straight from disk, loading only the data that is needed in a lazy fashion, by constructing the object straight from the saved file path:
```python
>>> import tempfile
>>> from pathlib import Path
>>> with tempfile.NamedTemporaryFile(suffix="nrt") as f:
...     fp = Path(f.name)
...     J.save(fp)
...     J2 = JointNestedRaggedTensorDict(tensors_fp=fp)
...     J2 = J2[2,1:]
>>> pprint_dense(J2.to_dense())
T
[7]
.
---
.
dim1/mask
[[ True  True]]
.
id
[[8 9]]
.
val
[[1. 0.]]
```
In this call, only a fraction of the data stored in the file is actually loaded from disk to materialize the final tensor (in this case, it is quite a large fraction, because our tensor is so small overall, but in a larger tensor this is more significant).
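The real class delegates this to HuggingFace safetensors; the following standard-library sketch merely illustrates the idea of lazy loading: read the small offsets array eagerly, then seek to just the byte range that one slice needs, instead of reading the whole values array.

```python
import os
import struct
import tempfile

# Flat int32 values for [[1, 2, 3], [4, 5], [6, 7]] plus their offsets.
values = [1, 2, 3, 4, 5, 6, 7]
offsets = [0, 3, 5, 7]

# Write a toy file: offset count, then offsets, then values.
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(struct.pack("<I", len(offsets)))
    f.write(struct.pack(f"<{len(offsets)}I", *offsets))
    f.write(struct.pack(f"<{len(values)}i", *values))
    path = f.name

def load_row(path, i):
    """Load only row i's bytes: read the offsets, then seek to the row."""
    with open(path, "rb") as f:
        (n,) = struct.unpack("<I", f.read(4))
        offs = struct.unpack(f"<{n}I", f.read(4 * n))
        lo, hi = offs[i], offs[i + 1]
        f.seek(4 + 4 * n + 4 * lo)  # jump straight to row i's values
        return list(struct.unpack(f"<{hi - lo}i", f.read(4 * (hi - lo))))

row1 = load_row(path, 1)
print(row1)  # [4, 5]
os.remove(path)
```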
## Performance
Performance over time on various aspects of an approximate PyTorch dataset using this repo can be seen at https://mmcdermott.github.io/nested_ragged_tensors/dev/bench/

In older commits (see the GitHub history for more details), you could also run `python performance_tests/test_times.py` for a comparison across several strategies of using these data. A re-introduction of this feature in a more user-friendly format is planned, in concert with the tracking over time of the performance of this package documented at the above link.