
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
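
For illustration, here is a minimal sketch of how a model can read from and write to a tensordict, using tensordict.nn.TensorDictModule (the key names "observation" and "hidden" below are arbitrary placeholders chosen for this sketch, not part of the loop above):

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> # wrap a plain module so that it reads "observation" and writes "hidden"
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["observation"], out_keys=["hidden"])
>>> data = TensorDict({"observation": torch.randn(8, 3)}, batch_size=[8])
>>> data = module(data)  # the output is written back into the tensordict
>>> assert data["hidden"].shape == torch.Size([8, 4])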

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the leading dimensions of each tensor must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating individual tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).batch_size)
torch.Size([12])
>>> print(tensordict.reshape(-1).batch_size)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
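
A brief sketch of a few of these operations, assuming tensordict is the [3, 4]-shaped example defined above:

>>> td_cpu = tensordict.to("cpu")        # device casting
>>> td_clone = tensordict.clone()        # deep copy of keys and values
>>> chunks = tensordict.split(2, dim=0)  # tensordicts of batch sizes [2, 4] and [1, 4]
>>> td_big = tensordict.unsqueeze(0).expand(2, 3, 4)  # broadcast to a larger batch size
>>> assert td_big.shape == torch.Size([2, 3, 4])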

If a functionality is missing, it is easy to implement it with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read huge amounts of data.
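
For instance, a tensordict can be converted to memory-mapped tensors under a shared directory on one node and lazily loaded on another. A minimal sketch, assuming a hypothetical shared path /shared/scratch/td and a tensordict version that exposes memmap_ and load_memmap:

>>> # on the writer node: move the contents to memory-mapped tensors on disk
>>> data = TensorDict({"a": torch.zeros(3), ("b", "c"): torch.ones(3)}, batch_size=[3])
>>> data.memmap_(prefix="/shared/scratch/td")
>>> # on a reader node with access to the same filesystem
>>> data_read = TensorDict.load_memmap("/shared/scratch/td")
>>> assert (data_read["b", "c"] == 1).all()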

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices: pre-allocation does not require you to go through the tensordict in order.
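
A sketch of the shuffled-index case, reusing foo() and N from above:

perm = torch.randperm(N)
tensordict = TensorDict({}, batch_size=[N])
for i in perm.tolist():
    # the first assignment pre-allocates every key with batch size N,
    # subsequent assignments write in-place at index i
    tensordict[i] = foo()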

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size must match the parent batch size in its leading dimensions (but can be longer).

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
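
Nested values can also be written with the same tuple-of-strings syntax; the following sketch (based on the example above, with hypothetical key names) writes into the nested tensordict:

>>> tensordict["key 2", "sub-key 2"] = torch.ones(3, 4, 5)
>>> tensordict.set(("key 2", "sub-key 3"), torch.zeros(3, 4, 5, 2))  # equivalent, using set()
>>> assert tensordict["key 2", "sub-key 2"].shape == torch.Size([3, 4, 5])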

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
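
For example, a tensorclass field can itself be another tensorclass or a TensorDict. A minimal sketch, where the Sample class is hypothetical and data is the MyData instance from above:

>>> @tensorclass
... class Sample:
...     data: MyData
...     score: torch.Tensor
...
>>> sample = Sample(data=data, score=torch.rand(100), batch_size=[100])
>>> assert sample[:10].data.image.shape == torch.Size([10, 3, 64, 64])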

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may still be introduced, but they should come with a deprecation warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
