TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
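
For illustration, here is a minimal sketch of what such tensordict-compatible components could look like. The "obs", "out" and "target" keys are made up for the example and are not part of any API:

import torch
from torch import nn
from tensordict import TensorDict

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, tensordict):
        # read the input entry and write the prediction back to the tensordict
        tensordict["out"] = self.linear(tensordict["obs"])
        return tensordict

def loss_module(tensordict):
    # reduce tensordict entries to a scalar loss
    return (tensordict["out"] - tensordict["target"]).pow(2).mean()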

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must be consistent. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device; each tensor written to it will then be sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating individual tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
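
For instance (a quick sketch of a few of these methods, applied to the tensordict defined above):

>>> tensordict_cpu = tensordict.to("cpu")    # device cast
>>> tensordict_clone = tensordict.clone()    # deep copy
>>> chunks = tensordict.split(2, dim=1)      # tensordicts of batch size [3, 2]
>>> rows = tensordict.unbind(0)              # tuple of 3 tensordicts of batch size [4]
>>> expanded = tensordict.expand(2, 3, 4)    # tensordict of batch size [2, 3, 4]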

If a functionality is missing, it is easy to emulate it by applying a function to each entry with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read huge amounts of data.
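
A minimal sketch, assuming a shared filesystem path /shared/scratch (the path is hypothetical, and the exact signatures may vary across versions):

>>> # on the writer node
>>> data = TensorDict({"a": torch.zeros(1000, 128)}, batch_size=[1000])
>>> data.memmap_(prefix="/shared/scratch/data")  # entries are moved to memory-mapped storage
>>> # on a reader node
>>> data_read = TensorDict.load_memmap("/shared/scratch/data")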

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
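
For instance, tensordict.nn.TensorDictModule wraps a regular nn.Module so that it reads its inputs from, and writes its outputs to, a tensordict (a minimal sketch):

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = module(td)
>>> assert td["y"].shape == torch.Size([10, 4])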

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which results in the same tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i==0, your empty tensordict will automatically be populated with empty tensors of batch-size N. After that, updates will be written in-place. Note that this would also work with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
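
For instance, reusing foo() above, the indices can be visited in any order:

tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():  # shuffled write order
    tensordict[i] = foo()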

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size must match (but could be longer than) the parent batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
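
The reverse operation, unflatten_keys, restores the nested structure:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested["key 2", "sub-key"] == tensordict["key 2", "sub-key"]).all()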

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
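
For instance, a tensorclass can hold another tensorclass (or a TensorDict) as a field; the Batch class below is made up for the example:

>>> @tensorclass
... class Batch:
...     data: MyData
...     stats: torch.Tensor
...
>>> batch = Batch(data=data, stats=torch.zeros(100), batch_size=[100])
>>> assert batch[:10].data.image.shape == torch.Size([10, 3, 64, 64])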

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
