TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the leading dimensions of each tensor must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
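
For instance, a brief sketch of a few of these operations:

>>> tensordict = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> cloned = tensordict.clone()          # deep copy of the structure and its tensors
>>> tensordict.share_memory_()           # place every tensor in shared memory, in-place
>>> chunks = tensordict.split(2, dim=1)  # two tensordicts with batch_size [3, 2]
>>> rows = tensordict.unbind(0)          # three tensordicts with batch_size [4]
>>> expanded = tensordict.expand(2, 3, 4)
>>> assert expanded.batch_size == torch.Size([2, 3, 4])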

If a functionality is missing, it is easy to implement it entry-wise with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
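
The in-place variant apply_() writes each result back to the corresponding entry, for instance to zero every tensor regardless of dtype:

tensordict.apply_(lambda tensor: tensor.zero_())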

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
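
As a sketch, assuming a shared scratch path such as /shared/scratch (hypothetical here) and the memmap_() / load_memmap() pair:

>>> # on the producer node: dump the tensordict to the shared file system
>>> data = TensorDict({"a": torch.randn(1000, 128)}, batch_size=[1000])
>>> data.memmap_(prefix="/shared/scratch/td")
>>> # on consumer nodes: memory-map the same files without copying the data
>>> data = TensorDict.load_memmap("/shared/scratch/td")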

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
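
As a minimal sketch (assuming PyTorch 2.0+ for torch.compile), a tensordict module can be compiled like any other nn.Module:

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> compiled_module = torch.compile(module)
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = compiled_module(td)  # reads "x", writes the result under "y"
>>> assert td["y"].shape == torch.Size([10, 4])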

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

On the first assignment, the empty tensordict is automatically populated with empty tensors of batch size N; after that, updates are written in-place. This also works with a shuffled series of indices, since pre-allocation does not require you to go through the tensordict in order, as the sketch below shows.
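
For example, a sketch of the shuffled case, reusing foo() and N from above:

import random

indices = list(range(N))
random.shuffle(indices)

tensordict = TensorDict({}, batch_size=[N])
for i in indices:
    # the first assignment, whatever its index, allocates empty tensors of
    # batch size N; every later assignment fills them in-place
    tensordict[i] = foo()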

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict be indexable under the parent tensordict, i.e. its batch size must start with (but may be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
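
The nested structure can be recovered from the flattened tensordict with unflatten_keys:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested == tensordict).all()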

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator for creating custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) and calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
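
For instance, here is a small sketch that nests the MyData instance defined above inside another tensorclass (the Batch class is purely illustrative):

>>> @tensorclass
... class Batch:
...     data: MyData
...     extra: torch.Tensor
...
>>> batch = Batch(data=data, extra=torch.zeros(100), batch_size=[100])
>>> print(batch[:5].batch_size)
torch.Size([5])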

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and upward, as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
