TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])
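
Boolean masks over the batch dimensions are also supported; here is a minimal sketch reusing the tensordict defined just above:

>>> mask = torch.zeros(3, 4, dtype=torch.bool)
>>> mask[0] = True  # select the first row of the batch
>>> masked_tensordict = tensordict[mask]
>>> assert masked_tensordict.shape == torch.Size([4])
>>> assert masked_tensordict["key 1"].shape == torch.Size([4, 5])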

Similarly, one can build tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and more.
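
A minimal sketch of a few of these operations, reusing the tensordict with batch_size [3, 4] from above (the CUDA transfer assumes a GPU is available):

>>> td_cuda = tensordict.to("cuda:0")                       # device-to-device transfer
>>> td_shared = tensordict.clone().share_memory_()          # clone, then place in shared memory
>>> td_a, td_b = tensordict.split([1, 2], dim=0)            # batch sizes [1, 4] and [2, 4]
>>> td_cols = tensordict.unbind(1)                          # tuple of 4 tensordicts of batch size [3]
>>> td_expanded = tensordict.unsqueeze(0).expand(2, 3, 4)   # batch size [2, 3, 4]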

If a functionality is missing, it is easy to implement it with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
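
The in-place counterpart apply_() writes the results back into the same tensordict; a minimal sketch:

# in-place variant: overwrite every entry with zeros
tensordict.apply_(lambda tensor: tensor.zero_())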

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
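
A minimal sketch of that workflow (the scratch-space path below is an arbitrary placeholder):

>>> # on the sending node: write the tensordict as memory-mapped tensors on the shared filesystem
>>> data.memmap_(prefix="/shared/scratch/td_data")
>>> # on the receiving node: lazily load the same data without sending it over the network
>>> data_read = TensorDict.load_memmap("/shared/scratch/td_data")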

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights for model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
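
As a hedged sketch of what this looks like in practice (TensorDictModule comes from tensordict.nn; the key names below are arbitrary, and torch.compile requires PyTorch 2.0+):

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> data = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> data = module(data)               # reads "x", writes "y"
>>> compiled = torch.compile(module)  # the same module can be compiled
>>> data = compiled(data)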

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

On the first assignment (i == 0 above), the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices: pre-allocation does not require you to go through the tensordict in order, as sketched below.
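
A minimal sketch of the shuffled variant (foo() and N as above):

indices = torch.randperm(N)
tensordict = TensorDict({}, batch_size=[N])
for i in indices.tolist():
    # the first assignment, whichever index it hits, triggers the preallocation
    tensordict[i] = foo()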

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict must be indexable under the parent tensordict, i.e. its batch size must match (but can be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
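
Nested values can be written with the same tuple syntax, and the flat view can be nested back; a small sketch building on the example above:

>>> tensordict["key 2", "sub-key"] = torch.zeros(3, 4, 5, 6)  # write through the nesting
>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")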

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack), or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
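
For instance, a tensorclass can itself hold a tensorclass field; a hedged sketch reusing MyData and data from above:

>>> @tensorclass
... class Sample:
...     data: MyData
...     score: torch.Tensor
...
>>> sample = Sample(data=data, score=torch.randn(100), batch_size=[100])
>>> print(sample[:10].data.image.shape)
torch.Size([10, 3, 64, 64])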

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and upwards and with PyTorch 1.12 and upwards.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
