
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must match: the leading dimensions of every entry must be equal to the batch_size. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")
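
Conversely, constructing a tensordict whose entries' leading dimensions conflict with the batch_size fails. A quick sketch (the exact exception type may vary across versions):

>>> try:
...     TensorDict({"a": torch.ones(3, 5)}, batch_size=[3, 4])
... except RuntimeError:
...     print("leading dims must match batch_size")
leading dims must match batch_size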

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])
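
Boolean masks follow the same semantics as for tensors; as a brief sketch:

>>> mask = torch.tensor([True, False, True])
>>> assert tensordict[mask].shape == torch.Size([2, 4])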

Similarly, one can build tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])
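
squeeze undoes the unsqueeze above, as a quick sketch shows:

>>> print(tensordict.unsqueeze(-1).squeeze(-1).shape)
torch.Size([3, 4])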

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and more.
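
For instance, a short sketch of a few of these operations on the tensordict defined above:

>>> tensordict_cpu = tensordict.to("cpu")        # device casting (a no-op here)
>>> expanded = tensordict.expand(2, 3, 4)        # broadcast a new leading batch dim
>>> assert expanded.shape == torch.Size([2, 3, 4])
>>> first, second = tensordict.split(2, dim=1)   # split along a batch dimension
>>> assert first.shape == torch.Size([3, 2])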

If a functionality is missing, it is easy to implement it using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
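
The in-place variant apply_ writes the result back into the same tensordict; a brief sketch:

tensordict.apply_(lambda tensor: tensor.clamp(0, 1))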

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
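
As a minimal sketch, assuming /mnt/shared is a (hypothetical) scratch space visible to all nodes:

>>> data = TensorDict({"a": torch.zeros(1000, 1000)}, batch_size=[1000])
>>> data.memmap_(prefix="/mnt/shared/buffer")  # entries become memory-mapped tensors
>>> assert data.is_memmap()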

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
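
A tensordict module reads its inputs from and writes its outputs to a tensordict. As a brief sketch using TensorDictModule, which wraps a regular nn.Module with declared input and output keys:

>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = module(td)  # reads "x", writes "y" into the same tensordict
>>> assert td["y"].shape == torch.Size([10, 4])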

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices: pre-allocation does not require you to go through the tensordict in an ordered fashion.
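
A sketch of the shuffled-index case (the first write, whatever its index, triggers the preallocation):

tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    tensordict[i] = foo()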

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
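
Tuple keys can also be used to write nested values, and unflatten_keys inverts the flattening; a short sketch:

>>> tensordict["key 2", "sub-key 2"] = torch.ones(3, 4, 5)  # writes into the nested tensordict
>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")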

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
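
Since tensorclasses share the TensorDict machinery, indexing and stacking preserve the class; a short sketch:

>>> sub_data = data[:10]                    # indexing returns a MyData instance
>>> assert isinstance(sub_data, MyData)
>>> stacked = torch.stack([data, data], 0)  # stacking does too
>>> assert stacked.batch_size == torch.Size([2, 100])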

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking (bc-breaking) changes may be introduced, but they should come with a warning. Hopefully, these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
