
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to a device, and point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
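Each stage only has to agree with the others on the keys it reads and writes. As a minimal sketch, a hypothetical loss_module for the loop above could look like this (the "prediction" and "target" keys are illustrative, not part of any API):

def loss_module(tensordict):
    # read only the entries this stage needs; everything else passes through untouched
    prediction = tensordict["prediction"]
    target = tensordict["target"]
    return ((prediction - target) ** 2).mean()

Swapping tasks then amounts to changing which keys the model writes, not rewriting the loop.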

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must agree. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device; every tensor written to it will then be sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and more.
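A brief sketch of a few of these operations on the tensordict above (the method names mirror their torch.Tensor counterparts; the device cast assumes a CUDA device is available):

>>> td_cuda = tensordict.to("cuda:0")         # device-to-device casting
>>> td_clone = tensordict.clone()             # copy of every entry
>>> td_shared = tensordict.share_memory_()    # place the entries in shared memory
>>> chunks = tensordict.split(2, dim=0)       # chunks of size 2 (plus remainder) along dim 0
>>> items = tensordict.unbind(1)              # tuple of 4 tensordicts with batch_size [3]
>>> td_expanded = tensordict.expand(2, 3, 4)  # prepend an expanded leading dimension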

If a functionality is missing, it is easy to implement it with apply() or apply_():

# apply() returns a new tensordict built from the values returned by the function
tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
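apply_() is the in-place variant: the results are written back into the same tensordict. A quick sketch:

tensordict.apply_(lambda tensor: tensor * 0)  # writes tensor * 0 back into each entry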

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read huge amounts of data.
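For instance, a tensordict can be dumped to memory-mapped storage with memmap_(); a minimal sketch (the prefix path is hypothetical and stands for the shared scratch space):

>>> data = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, batch_size=[3])
>>> data.memmap_(prefix="/shared/scratch/td")  # entries are now backed by memory-mapped files
>>> assert data.is_memmap()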

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights for model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which results in the same tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
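For instance, filling the tensordict in a shuffled order works just as well; a quick sketch:

tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N):
    tensordict[int(i)] = foo()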

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict must be indexable under the parent tensordict, i.e. its batch size must match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
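The reverse operation, unflatten_keys, rebuilds the hierarchy from the separator; as a quick round-trip check on the tensordict above:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested == tensordict).all()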

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
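For example, a tensorclass field can itself hold a tensorclass (or a TensorDict). A hedged sketch reusing MyData from above (the Sample class and its fields are illustrative):

>>> @tensorclass
... class Sample:
...     data: MyData
...     metadata: torch.Tensor
...
>>> sample = Sample(data=data, metadata=torch.zeros(100), batch_size=[100])
>>> print(sample.reshape(10, 10).batch_size)
torch.Size([10, 10])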

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, and PyTorch 1.12 and later.

To enjoy the latest features, install the nightly build instead:

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
