TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
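
As a rough sketch of what such modules might look like, the model and the loss module below simply read and write named entries of the tensordict (the key names "observation", "prediction" and "target" are made up for illustration):

import torch
from torch import nn
from tensordict import TensorDict

class Model(nn.Module):
    # reads "observation", writes "prediction"
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 2)

    def forward(self, tensordict):
        tensordict["prediction"] = self.net(tensordict["observation"])
        return tensordict

def loss_module(tensordict):
    # reads "prediction" and "target", returns a scalar loss
    return nn.functional.mse_loss(tensordict["prediction"], tensordict["target"])

model = Model()
data = TensorDict({
    "observation": torch.randn(32, 8),
    "target": torch.randn(32, 2),
}, batch_size=[32])
loss = loss_module(model(data))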

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")
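
For instance, an entry whose leading dimensions do not match the batch_size is rejected (a minimal sketch; the exact exception type may vary across versions):

>>> tensordict = TensorDict({}, batch_size=[3, 4])
>>> try:
...     tensordict["bad"] = torch.randn(2, 4, 5)  # leading dims must be (3, 4)
... except Exception as err:
...     print("rejected:", type(err).__name__)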

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and so on.
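
A brief, non-exhaustive sketch of some of these operations (assuming a CUDA device is available for the device transfer):

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
... }, batch_size=[3, 4])
>>> td_cuda = tensordict.to("cuda:0")        # device-to-device copy
>>> td_clone = tensordict.clone()            # deep copy
>>> chunks = tensordict.split(2, dim=1)      # split along a batch dimension
>>> rows = tensordict.unbind(0)              # tuple of tensordicts along dim 0
>>> expanded = tensordict.expand(2, 3, 4)    # prepend a batch dimension
>>> shared = tensordict.share_memory_()      # place the tensors in shared memory (in-place)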

If a functionality is missing, it is easy to apply it to every entry with apply() or its in-place counterpart apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
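
A minimal sketch of that workflow, assuming /mnt/shared is a path visible to all nodes (memmap_ writes the tensors to memory-mapped files under the given prefix; load_memmap is available in recent versions to map them back lazily):

>>> # on the writer node: back the tensordict with memory-mapped files
>>> data = TensorDict({"a": torch.zeros(1000, 128)}, batch_size=[1000])
>>> data.memmap_(prefix="/mnt/shared/data")
>>> # on a reader node: map the same files back into a tensordict
>>> data = TensorDict.load_memmap("/mnt/shared/data")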

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
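
A sketch of what this can look like with a TensorDictModule wrapper (whether compilation yields speed-ups, or falls back to eager execution, depends on your PyTorch and tensordict versions):

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> compiled = torch.compile(module)
>>> data = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> data = compiled(data)
>>> assert data["y"].shape == torch.Size([10, 4])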

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
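
For instance, filling the slots in a random order works just as well, since the first write triggers the preallocation (a small sketch reusing the foo() defined above):

>>> import random
>>> N = 10
>>> tensordict = TensorDict({}, batch_size=[N])
>>> indices = list(range(N))
>>> random.shuffle(indices)
>>> for i in indices:
...     tensordict[i] = foo()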

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
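
Nested values can also be written with a tuple of keys, and a flat tensordict can be folded back into a nested one with unflatten_keys (a small sketch; "other-sub-key" is a made-up entry):

>>> tensordict["key 2", "other-sub-key"] = torch.zeros(3, 4, 5)  # creates the nested entry
>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested["key 2", "sub-key"] == tensordict["key 2", "sub-key"]).all()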

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
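
For instance, a tensorclass can hold another tensorclass (or a TensorDict) as a field; a small, hypothetical sketch building on MyData above:

>>> @tensorclass
... class Annotated:
...     data: MyData
...     score: torch.Tensor
...
>>> annotated = Annotated(data=data, score=torch.rand(100), batch_size=[100])
>>> print(annotated[:5].batch_size)
torch.Size([5])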

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and above as well as PyTorch 1.12 and above.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
