
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device, or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
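To make this concrete, here is a minimal sketch of a model that reads from and writes to a tensordict, built with tensordict.nn.TensorDictModule; the key names "input" and "output" are illustrative choices:

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> # wrap a plain nn.Module so that it reads and writes tensordict keys
>>> model = TensorDictModule(nn.Linear(3, 4), in_keys=["input"], out_keys=["output"])
>>> tensordict = TensorDict({"input": torch.randn(8, 3)}, batch_size=[8])
>>> tensordict = model(tensordict)  # the result is written under "output"
>>> assert tensordict["output"].shape == torch.Size([8, 4])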

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of every tensor in the tensordict. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it will be sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])
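
Boolean masks over the batch dimensions work as well, exactly as they do for tensors:

>>> mask = torch.rand(3, 4) > 0.5
>>> masked_tensordict = tensordict[mask]
>>> # both batch dimensions covered by the mask collapse into one
>>> assert masked_tensordict.shape == torch.Size([int(mask.sum())])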

Similarly, one can build tensordicts by stacking or concatenating multiple tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and more.
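
A brief sketch of a few of these operations (the commented line assumes a CUDA-enabled machine):

>>> tensordict_clone = tensordict.clone()
>>> shared_tensordict = tensordict.clone().share_memory_()
>>> first, rest = tensordict.split([1, 2], dim=0)
>>> assert first.shape == torch.Size([1, 4])
>>> # tensordict_cuda = tensordict.to("cuda:0")  # device-to-device transfer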

If a functionality is missing, it is easy to implement it by applying a function to every entry with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
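
The same pattern covers, e.g., dtype casting; apply() returns a new tensordict with the function applied to every entry:

tensordict_double = tensordict.apply(lambda tensor: tensor.double())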

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read huge amounts of data.
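
As an illustrative sketch, a tensordict can be dumped to such a scratch space with memmap_(); the prefix path below is hypothetical and assumed to be visible from every node:

>>> data = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, batch_size=[3])
>>> data.memmap_(prefix="/shared/scratch/data")  # hypothetical shared path
>>> assert data.is_memmap()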

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
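
A minimal sketch of what this enables, assuming PyTorch 2.0 or later for torch.compile (the module below is illustrative):

>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> compiled_module = torch.compile(module)  # requires PyTorch 2.0+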

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices: pre-allocation does not require you to go through the tensordict in an ordered fashion.
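
A minimal sketch of that shuffled variant, reusing foo() from above:

N = 10
tensordict = TensorDict({}, batch_size=[N])
# visiting the indices in any order works; the first write triggers allocation
for i in torch.randperm(N):
    tensordict[int(i)] = foo()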

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")
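
The operation is reversible: unflatten_keys() restores the nested structure:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert tensordict_nested["key 2", "sub-key"].shape == torch.Size([3, 4, 5, 6])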

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
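
Writing follows the same convention; here "sub-key 2" is a new, illustrative entry created under the nested tensordict:

>>> tensordict["key 2", "sub-key 2"] = torch.zeros(3, 4, 5)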

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
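
For instance, a tensorclass can itself be a field of another tensorclass; a minimal sketch reusing MyData from above (TrainBatch is an illustrative name):

>>> @tensorclass
... class TrainBatch:
...     data: MyData
...     weight: torch.Tensor
...
>>> batch = TrainBatch(data=data, weight=torch.rand(100), batch_size=[100])
>>> assert batch[:10].data.image.shape == torch.Size([10, 3, 64, 64])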

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This requires Python 3.7 or later and PyTorch 1.12 or later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
