TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can be easily used across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of every tensor. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating multiple tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).batch_size)
torch.Size([12])
>>> print(tensordict.reshape(-1).batch_size)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
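
As a quick illustration, here is a minimal sketch of a few of these operations (the method names mirror their tensor counterparts; exact signatures may vary across versions):

>>> td = TensorDict({"a": torch.randn(6, 4)}, batch_size=[6])
>>> td_clone = td.clone()                # deep copy of every entry
>>> td_cpu = td.to("cpu")                # cast the whole tensordict to a device
>>> chunks = td.split(2, dim=0)          # 3 tensordicts with batch_size [2]
>>> assert chunks[0].shape == torch.Size([2])
>>> rows = td.unbind(0)                  # 6 tensordicts with batch_size []
>>> expanded = td.expand(3, 6)           # add a leading dimension
>>> assert expanded.shape == torch.Size([3, 6])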

If a functionality is missing, it is easy to implement it using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read huge amounts of data.
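
As a sketch of how this could look (the scratch-space path is hypothetical; memmap_() converts the entries to memory-mapped tensors in place):

>>> data = TensorDict({"a": torch.zeros(1000, 128)}, batch_size=[1000])
>>> data.memmap_(prefix="/shared/scratch/data")  # entries become memory-mapped files on disk
>>> assert data.is_memmap()
>>> # other nodes with access to the same path can map these files directly,
>>> # instead of receiving the tensors over the network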

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i==0, your empty tensordict will automatically be populated with empty tensors of batch size N. After that, updates will be written in-place. Note that this would also work with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion), as sketched below.
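
For instance, a minimal sketch of the shuffled case, reusing foo() from above:

>>> N = 10
>>> tensordict = TensorDict({}, batch_size=[N])
>>> for i in torch.randperm(N).tolist():  # indices visited in random order
...     tensordict[i] = foo()
>>> assert tensordict["a"].shape == torch.Size([N, 3])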

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) the parent batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
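
Key tuples can also be used to write nested values (the key name below is only illustrative):

>>> tensordict["key 2", "other sub-key"] = torch.zeros(3, 4, 5)
>>> assert tensordict["key 2", "other sub-key"].shape == torch.Size([3, 4, 5])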

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) and calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
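
As a quick sketch of nesting (the nested field and names are illustrative, reusing the MyData class and the data instance from above):

>>> @tensorclass
... class MyNestedData:
...     image: torch.Tensor
...     extra: MyData  # a tensorclass (or a TensorDict) can be a field of another tensorclass
...
>>> nested = MyNestedData(
...     image=torch.randn(100, 3, 64, 64),
...     extra=data,  # the MyData instance built above
...     batch_size=[100],
... )
>>> assert nested.extra.label.shape == torch.Size([100])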

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that bc-breaking changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
