TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
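
As a minimal sketch of that convention (the Classifier module and the "observation"/"logits" keys below are illustrative, not part of the library), a model in such a loop simply reads and writes agreed-upon keys:

import torch
from torch import nn
from tensordict import TensorDict

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(16, 10)

    def forward(self, tensordict):
        # read the input key and write the prediction back into the same tensordict
        tensordict["logits"] = self.net(tensordict["observation"])
        return tensordict

data = TensorDict({"observation": torch.randn(8, 16)}, batch_size=[8])
data = Classifier()(data)
assert data["logits"].shape == torch.Size([8, 10])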

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The leading dimensions of every tensor must match the batch_size. Beyond that, the tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimensions:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build new tensordicts by stacking or concatenating existing ones:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and so on.
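
For example (a short sketch of a few of these operations):

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> tensordict_cpu = tensordict.to("cpu")          # device casting (a no-op here)
>>> tensordict_clone = tensordict.clone()          # deep copy of every entry
>>> td1, td2, td3 = tensordict.unbind(0)           # three tensordicts of batch size [4]
>>> first, rest = tensordict.split([1, 2], dim=0)  # batch sizes [1, 4] and [2, 4]
>>> expanded = tensordict.expand(2, 3, 4)          # batch size [2, 3, 4]
>>> shared = tensordict.clone().share_memory_()    # place the copy in shared memory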

If a functionality is missing, it is easy to implement it with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
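
For instance, a tensordict can be written as a collection of memory-mapped tensors on the shared filesystem (a sketch; the path below is hypothetical and must be visible to all nodes):

>>> data = TensorDict({"a": torch.zeros(3), ("b", "c"): torch.ones(3)}, [3])
>>> # write every tensor to a memory-mapped file under the shared prefix
>>> data.memmap_(prefix="/path/to/shared/scratch")
>>> assert data.is_memmap()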

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights for model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
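
As a minimal sketch of what that compatibility enables (assuming a recent PyTorch that ships torch.compile; the module and keys below are illustrative):

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> # map a plain nn.Linear onto tensordict keys
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> compiled_module = torch.compile(module)
>>> data = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> data = compiled_module(data)
>>> assert data["y"].shape == torch.Size([10, 4])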

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict of batch size N (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices: pre-allocation does not require you to go through the tensordict in an ordered fashion, as the sketch below shows.
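
For instance, filling the indices in a random order (a short sketch reusing foo() from above) works just as well:

N = 10
tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    tensordict[i] = foo()  # the first assignment preallocates the whole batch of size N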

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size must match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested values can be achieved with a single index made of a tuple of keys:

>>> sub_value = tensordict["key 2", "sub-key"]
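
Nested entries can be written the same way, and the flat representation can be turned back into a nested one (a short sketch):

>>> tensordict["key 2", "another sub-key"] = torch.zeros(3, 4, 5)  # written inside the "key 2" sub-tensordict
>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")  # back to the hierarchical form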

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
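
For instance, a tensorclass field can itself hold another tensorclass (a sketch reusing MyData and data from above):

>>> @tensorclass
... class LabelledBatch:
...     data: MyData
...     weight: torch.Tensor
...
>>> batch = LabelledBatch(data=data, weight=torch.rand(100), batch_size=[100])
>>> assert batch[:10].data.image.shape == torch.Size([10, 3, 64, 64])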

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-incompatible (bc-breaking) changes may be introduced, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
