TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of every tensor in the tensordict. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, which will send to that device every tensor written into the tensordict:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating individual tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in place (or not), split them, unbind them, expand them, and more.
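For instance, here is a short sketch of a few of these operations, reusing the tensordict above (the .to() call assumes a CUDA device is available):

>>> td_cuda = tensordict.to("cuda:0")    # sends every tensor to the GPU
>>> td_clone = tensordict.clone()        # deep copy of the whole structure
>>> chunks = tensordict.split(2, dim=0)  # tuple of tensordicts split along dim 0
>>> td_big = tensordict.expand(2, 3, 4)  # broadcasts the batch size to [2, 3, 4]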

If some functionality is missing, it is easy to implement it with apply() or apply_():

# applies the function to every leaf tensor and returns a new tensordict
tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
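The in-place variant apply_() writes the result of the function back into the tensordict it is called on rather than returning a new one, e.g.

tensordict.apply_(lambda tensor: tensor + 1)  # increments every entry in place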

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
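As a minimal sketch of that workflow (the shared path below is hypothetical), memmap_() dumps the contents as memory-mapped tensors on disk, which any worker that sees the same filesystem can then read from:

>>> # on the writing node: dump the contents as memory-mapped tensors
>>> data = TensorDict({"a": torch.zeros(3)}, [3])
>>> data.memmap_(prefix="/shared/scratch/td")  # hypothetical shared scratch path
>>> assert data.is_memmap()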

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
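As a minimal sketch, a tensordict module wraps a regular nn.Module by declaring which keys it reads from and writes to the tensordict; compiling it is then a one-liner (the torch.compile call assumes PyTorch 2.0+):

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = module(td)  # reads "x", writes "y" into the tensordict
>>> compiled_module = torch.compile(module)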

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (here shown with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in place. Note that this also works with a shuffled series of indices; pre-allocation does not require you to go through the tensordict in an ordered fashion, as the sketch below shows.
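For example, a shuffled fill, reusing the foo() defined above:

tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    tensordict[i] = foo()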

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. the parent batch size must be a prefix of the sub-tensordict's batch size (which can therefore have more dimensions than the parent's).

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
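The reverse operation rebuilds the hierarchy from the flat representation, so the two forms round-trip:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested == tensordict).all()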

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) and the ability to call arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
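For instance, tensorclass instances can be stacked and indexed just like tensordicts; here is a short sketch reusing the data instance above:

>>> stacked = torch.stack([data, data], 0)
>>> print(stacked.batch_size)
torch.Size([2, 100])
>>> row = stacked[0]  # indexing returns a MyData instance again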

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
