TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
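
For instance, the model in the loop above can be any callable that reads from and writes to a tensordict, such as a tensordict.nn.TensorDictModule. Below is a minimal sketch of this pattern; the "observation", "prediction" and "target" keys and the loss function are illustrative, not part of the library:

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# wrap a regular nn.Module so that it reads "observation" and writes "prediction"
model = TensorDictModule(nn.Linear(8, 2), in_keys=["observation"], out_keys=["prediction"])

def loss_module(tensordict):
    # hypothetical mean-squared-error loss between prediction and target
    return ((tensordict["prediction"] - tensordict["target"]) ** 2).mean()

tensordict = TensorDict({
    "observation": torch.randn(32, 8),
    "target": torch.randn(32, 2),
}, batch_size=[32])
tensordict = model(tensordict)
loss = loss_module(tensordict)
loss.backward()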

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of each of the tensors. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it will be sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).batch_size)
torch.Size([12])
>>> print(tensordict.reshape(-1).batch_size)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
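
A short sketch of a few of these operations (the device and shapes are illustrative):

td = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
td_cpu = td.to("cpu")                    # cast / send to a device
td_shared = td.clone().share_memory_()   # place a copy in shared memory
td_expanded = td.expand(2, 3, 4)         # expand to a larger batch size
td_chunks = td.split(1, dim=0)           # split along the first batch dimension
td_parts = td.unbind(1)                  # unbind along the second batch dimension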

If a functionality is missing, it is easy to implement it using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
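
The same mechanism can be used, for instance, to cast every entry to a given dtype (a sketch):

# returns a new tensordict whose entries have been cast to float32
tensordict_float = tensordict.apply(lambda tensor: tensor.float())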

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read huge amounts of data.
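
For instance, a tensordict can be converted in-place to a memory-mapped storage living on the shared filesystem before being exchanged (the path below is illustrative):

data = TensorDict({"a": torch.zeros(1000, 128)}, batch_size=[1000])
data.memmap_(prefix="/path/to/shared/scratch")  # entries are now backed by memory-mapped files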

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
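
For instance, a tensordict module can be passed to torch.compile like any other module (a sketch, assuming a PyTorch version that ships torch.compile):

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
compiled_module = torch.compile(module)
data = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
data = compiled_module(data)  # reads "x", writes "y"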

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
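
For instance, filling the preallocated tensordict in a shuffled order works just as well (a sketch reusing foo() from above):

N = 10
tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    tensordict[i] = foo()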

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) the parent batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
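
Nested values can also be written with a single tuple index, and unflatten_keys performs the reverse of flatten_keys (a sketch):

>>> tensordict["key 2", "sub-key"] = torch.zeros(3, 4, 5, 6)
>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")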

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
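
For instance, a tensorclass field can itself be a tensorclass (or a TensorDict). A sketch building on MyData above, where the Batch class and its step field are illustrative:

>>> @tensorclass
... class Batch:
...     data: MyData
...     step: torch.Tensor
...
>>> batch = Batch(data=data, step=torch.arange(100), batch_size=[100])
>>> print(batch[:5].data.image.shape)
torch.Size([5, 3, 64, 64])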

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that bc-breaking changes may be introduced, but they should come with a warning. Hopefully, these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
