TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
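As a brief sketch of what "the model reads and writes tensordicts" can look like in practice, tensordict.nn.TensorDictModule wraps a regular nn.Module with declared input and output keys (the key names below are illustrative):

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# wrap a plain module: it reads "input" from the tensordict and writes "output"
model = TensorDictModule(nn.Linear(3, 4), in_keys=["input"], out_keys=["output"])
tensordict = TensorDict({"input": torch.randn(8, 3)}, batch_size=[8])
tensordict = model(tensordict)  # "output" is now a key of the tensordict
assert tensordict["output"].shape == torch.Size([8, 4])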

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating individual tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them and so on, as sketched below.
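A few of these operations in a brief sketch (the method names mirror their tensor counterparts; shapes assume the tensordict defined above, with batch_size [3, 4]):

>>> tensordict_cpu = tensordict.to("cpu")  # device casting
>>> tensordict_shared = tensordict_cpu.share_memory_()  # move to shared memory
>>> tensordict_clone = tensordict.clone()  # deep copy
>>> chunks = tensordict.split(2, dim=1)  # tuple of tensordicts of batch_size [3, 2]
>>> rows = tensordict.unbind(0)  # tuple of 3 tensordicts of batch_size [4]
>>> expanded = tensordict.expand(2, 3, 4)  # broadcast to batch_size [2, 3, 4]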

If a functionality is missing, it is easy to implement it with apply() (which returns a new tensordict) or its in-place counterpart apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read huge amounts of data.
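A minimal sketch, assuming a shared path such as /shared/scratch (the path is illustrative): memmap_() converts the content to memory-mapped tensors stored under that path, and TensorDict.load_memmap() maps the same files back without copying them.

>>> # on the writer node: materialize the data on the shared scratch space
>>> data.memmap_(prefix="/shared/scratch/data")
>>> # on a reader node: map the same files without copying their content
>>> data = TensorDict.load_memmap("/shared/scratch/data")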

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

N = 10
tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (when N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion), as sketched below.
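For instance, this brief sketch fills the tensordict in random order (assuming foo() and N as defined above):

perm = torch.randperm(N)
tensordict = TensorDict({}, batch_size=[N])
for i in perm.tolist():
    tensordict[i] = foo()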

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size must match (but could be longer than) the parent batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
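The reverse transformation is available through unflatten_keys, and tuple keys can also be used to write nested values (the new key below is illustrative):

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> tensordict["key 2", "another sub-key"] = torch.zeros(3, 4, 5)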

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
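A brief sketch of nesting (the class and field names are illustrative; a tensorclass field can hold another tensorclass or a TensorDict):

>>> @tensorclass
... class Sample:
...     data: MyData
...     weight: torch.Tensor
...
>>> sample = Sample(data=data, weight=torch.ones(100), batch_size=[100])
>>> print(sample[:10].batch_size)
torch.Size([10])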

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, and PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may still be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
