TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming | TensorDict for parameter serialization | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings. Whenever you need to execute an operation over a batch of tensors, TensorDict is there to help you.

The primary goal of TensorDict is to make your codebase more readable, compact, and modular. It abstracts away tailored operations, making your code less error-prone as it takes care of dispatching the operation on the leaves for you.

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be used across classification and segmentation tasks, among many others.

Features

General principles

Unlike other pytrees, TensorDict carries metadata that make it easy to query the state of the container. The main metadata are the batch_size (also referred to as shape), the device, the shared status (is_memmap or is_shared), the dimension names and the lock status.
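
For instance, these metadata can be queried directly (a minimal sketch):

>>> from tensordict import TensorDict
>>> import torch
>>> td = TensorDict({"x": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> td.batch_size  # also accessible as td.shape
torch.Size([3, 4])
>>> print(td.device)  # no device was set
None
>>> td.is_shared(), td.is_memmap()
(False, False)
>>> td.is_locked
False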

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of each of the tensors. The tensors can be of any dtype and device.

Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it will be sent to that device:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")

When a tensordict has a device, all write operations will cast the tensor to the TensorDict device:

>>> data["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert data["key 3"].device == torch.device("cuda:0")

Once the device is set, it can be cleared with the clear_device_ method.
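
For example:

>>> data.clear_device_()
>>> assert data.device is None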

TensorDict as a specialized dictionary

TensorDict possesses all the basic features of a dictionary such as clear, copy, fromkeys, get, items, keys, pop, popitem, setdefault, update and values.
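
A short sketch of a few of these, which follow the built-in dict semantics:

>>> data = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> _ = data.setdefault("b", torch.ones(3, 4))  # inserts "b" only if missing
>>> _ = data.update({"c": torch.rand(3, 4)})
>>> sorted(data.keys())
['a', 'b', 'c']
>>> value = data.pop("c")
>>> assert "c" not in data.keys()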

But that is not all, you can also store nested values in a tensordict:

>>> data["nested", "key"] = torch.zeros(3, 4) # the batch-size must match

and any nested tuple structure will be unravelled to make it easy to read code and write ops programmatically:

>>> data["nested", ("supernested", ("key",))] = torch.zeros(3, 4) # the batch-size must match
>>> assert (data["nested", "supernested", "key"] == 0).all()
>>> assert (("nested",), "supernested", (("key",),)) in data.keys(include_nested=True)  # this works too!

You can also store non-tensor data in tensordicts:

>>> data = TensorDict({"a-tensor": torch.randn(1, 2)}, batch_size=[1, 2])
>>> data["non-tensor"] = "a string!"
>>> assert data["non-tensor"] == "a string!"

Tensor-like features

[Nightly feature] TensorDict supports many common point-wise arithmetic operations such as ==, +, += and similar (provided that the underlying tensors support the said operation):

>>> td = TensorDict.fromkeys(["a", "b", "c"], 0)
>>> td += 1
>>> assert (td==1).all()

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = data[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(data.view(-1).batch_size)
torch.Size([12])
>>> print(data.reshape(-1).batch_size)
torch.Size([12])
>>> print(data.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
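
A few of these operations in action (a minimal sketch; the "cuda:0" target assumes a GPU is available):

>>> td = TensorDict({"a": torch.zeros(4, 5)}, batch_size=[4])
>>> td_cuda = td.to("cuda:0")               # device casting
>>> td_shared = td.clone().share_memory_()  # place a copy in shared memory
>>> chunks = td.split([2, 2], dim=0)        # two tensordicts of batch_size [2]
>>> rows = td.unbind(0)                     # four tensordicts of batch_size []
>>> big = td.expand(3, 4)                   # expand adds leading batch dimensions
>>> assert big.batch_size == torch.Size([3, 4])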

If a functionality is missing, it is easy to implement it using apply() or apply_():

tensordict_uniform = data.apply(lambda tensor: tensor.uniform_())

apply() can also be used to filter a tensordict, for instance:

data = TensorDict({"a": torch.tensor(1.0, dtype=torch.float), "b": torch.tensor(1, dtype=torch.int64)}, [])
data_float = data.apply(lambda x: x if x.dtype == torch.float else None) # contains only the "a" key
assert "b" not in data_float

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
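
For instance (a sketch; the folder path is hypothetical and assumed to be visible from every node):

>>> data = TensorDict({"a": torch.zeros(1000)}, batch_size=[1000])
>>> data.memmap_("/path/to/shared/scratch")  # writes the tensors to disk, in-place
>>> # on any other node with access to the same filesystem:
>>> data_remote = TensorDict.load_memmap("/path/to/shared/scratch")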

TensorDict for functional programming

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> params = TensorDict.from_module(model)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> with params.to_module(model):
...     out = model(x)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> def func(x, params):
...     with params.to_module(model):
...         return model(x)
>>> y = vmap(func, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and (soon) torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

TensorDict for parameter serialization and building datasets

TensorDict offers an API for parameter serialization that can be >3x faster than regular calls to torch.save(state_dict). Moreover, because tensors will be saved independently on disk, you can deserialize your checkpoint on an arbitrary slice of the model.

>>> model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 3))
>>> params = TensorDict.from_module(model)
>>> params.memmap("/path/to/saved/folder/", num_threads=16)  # adjust num_threads for speed
>>> # load params
>>> params = TensorDict.load_memmap("/path/to/saved/folder/", num_threads=16)
>>> params.to_module(model)  # load onto model
>>> params["0"].to_module(model[0])  # load on a slice of the model
>>> # in the latter case we could also have loaded only the slice we needed
>>> params0 = TensorDict.load_memmap("/path/to/saved/folder/0", num_threads=16)
>>> params0.to_module(model[0])  # load on a slice of the model

The same functionality can be used to access data in a dataset stored on disk. Storing a single contiguous tensor on disk, accessing it through the tensordict.MemoryMappedTensor primitive, and reading slices of it is not only much faster than loading single files one at a time, but it is also easier and safer (because there is no pickling or third-party library involved):

# allocate memory of the dataset on disk
data = TensorDict({
    "images": torch.zeros((128, 128, 3), dtype=torch.uint8),
    "labels": torch.zeros((), dtype=torch.int)}, batch_size=[])
data = data.expand(1000000)
data = data.memmap_like("/path/to/dataset")
# ==> Fill your dataset here
# Let's get 3 items of our dataset:
data[torch.tensor([1, 10000, 500000])]  # This is much faster than loading the 3 images independently

Preprocessing with TensorDict.map

Preprocessing huge contiguous (or not!) datasets can be done via TensorDict.map which will dispatch a task to various workers:

import torch
from tensordict import TensorDict

def process_data(data):
    images = data.get("images").flip(-2).clone()
    labels = data.get("labels") // 10
    # we update the td in-place and return it
    data.set_("images", images)  # flip the images
    data.set_("labels", labels)  # cluster the labels
    return data

if __name__ == "__main__":
    # `data` is the memory-mapped dataset created in the previous section
    data_preproc = data.map(process_data, num_workers=4, chunksize=0, pbar=True)  # process one item at a time

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    data = TensorDict({}, batch_size=[])
    data["a"] = torch.randn(3)
    data["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return data

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

data = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

data = TensorDict({}, batch_size=[N])
for i in range(N):
    data[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, your empty tensordict will automatically be populated with empty tensors of batch size N. After that, updates will be written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
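
A minimal sketch of the shuffled variant (reusing foo() and N from above):

import torch

data = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():  # visit the indices in random order
    data[i] = foo()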

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict must be indexable under the parent tensordict, i.e. its batch size must match the parent batch size as a prefix (it may have more dimensions than the parent).

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = data.flatten_keys(separator=".")
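
The reverse operation, unflatten_keys, restores the nested structure (a quick round-trip check):

>>> assert (tensordict_flatten["key 2.sub-key"] == data["key 2", "sub-key"]).all()
>>> data_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (data_nested == data).all()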

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = data["key 2", "sub-key"]

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
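
For instance, a tensorclass field can itself be a tensorclass (a minimal sketch reusing MyData and data from above):

>>> @tensorclass
... class Sample:
...     obs: torch.Tensor
...     annotated: MyData  # nested tensorclass field
...
>>> sample = Sample(obs=torch.randn(100, 7), annotated=data, batch_size=[100])
>>> sub = sample[:10]  # indexing propagates to the nested field
>>> assert sub.annotated.batch_size == torch.Size([10])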

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and above, as well as PyTorch 1.12 and above.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
