TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be used across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must be compliant. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and more.
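
A minimal sketch of a few of these operations on the tensordict above (the method names follow the TensorDict API, though exact signatures may vary between versions):

>>> td_clone = tensordict.clone()          # deep copy of the structure and tensors
>>> td_cpu = tensordict.to("cpu")          # device casting (a no-op if already on CPU)
>>> chunks = tensordict.split(2, dim=0)    # split along the first batch dimension
>>> rows = tensordict.unbind(0)            # tuple of tensordicts, one per row
>>> expanded = tensordict.expand(2, 3, 4)  # add a leading batch dimension
>>> assert expanded.shape == torch.Size([2, 3, 4])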

If a functionality is missing, it is easy to implement it using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
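
A rough sketch of what this looks like, assuming /mnt/shared is a scratch directory visible to every node (the path is hypothetical):

>>> data = TensorDict({"a": torch.zeros(1000, 128)}, batch_size=[1000])
>>> data.memmap_(prefix="/mnt/shared/data")  # tensors are moved to memory-mapped files in-place
>>> assert data.is_memmap()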

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
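
For example, a TensorDictModule wraps a regular nn.Module so that it reads its inputs from, and writes its outputs to, a tensordict; the wrapped module can then be compiled like any other module (a small sketch, with the torch.compile call assuming PyTorch 2.0 or later):

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = module(td)
>>> assert td["y"].shape == torch.Size([10, 4])
>>> compiled_module = torch.compile(module)  # optional: compile the tensordict module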

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

On the first write, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
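
As a small sketch reusing foo() from above, filling the preallocated tensordict in a random order works just as well:

N = 10
tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    tensordict[i] = foo()  # the first write preallocates, the rest are written in-place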

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict be indexable under the parent tensordict, i.e. its batch size must match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
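
Setting nested values works the same way, using a tuple of keys (a short sketch; "new sub-key" is a hypothetical key added for illustration):

>>> tensordict["key 2", "new sub-key"] = torch.zeros(3, 4, 5, 1)
>>> assert tensordict["key 2"]["new sub-key"].shape == torch.Size([3, 4, 5, 1])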

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
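
For instance, a tensorclass field can itself be a tensorclass or a TensorDict (a small sketch with hypothetical classes):

>>> @tensorclass
... class Annotation:
...     bbox: torch.Tensor
...
>>> @tensorclass
... class Sample:
...     image: torch.Tensor
...     annotation: Annotation
...
>>> sample = Sample(
...     image=torch.randn(100, 3, 64, 64),
...     annotation=Annotation(bbox=torch.zeros(100, 4), batch_size=[100]),
...     batch_size=[100],
... )
>>> assert sample[:10].annotation.bbox.shape == torch.Size([10, 4])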

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a deprecation warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
