
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the leading dimensions of each tensor must be consistent. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and more.
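
A quick sketch of a few of these operations (the methods below mirror their tensor counterparts):

>>> tensordict = TensorDict({"a": torch.randn(3, 4)}, batch_size=[3, 4])
>>> td_cpu = tensordict.to("cpu")                    # cast every tensor to a device
>>> td_shared = tensordict.clone().share_memory_()   # place the storage in shared memory
>>> td_a, td_b = tensordict.split([1, 2], dim=0)     # split along the first batch dimension
>>> td_big = tensordict.expand(2, 3, 4)              # expand to a larger batch size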

If a functionality is missing, it is easy to implement it using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
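
As a minimal sketch, a tensordict can be converted to memory-mapped storage in place with memmap_() (the path below is a placeholder for a scratch directory visible to all nodes):

>>> data = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
>>> data.memmap_(prefix="/path/to/shared/scratch")  # hypothetical shared path
>>> assert data.is_memmap()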

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights for model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
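
As a minimal sketch of what this enables (assuming, per the compatibility claim above, that torch.compile handles functions operating on tensordicts), a toy function can be compiled as it would be with plain tensors:

>>> @torch.compile
... def step(td):
...     # read from and write to the tensordict inside the compiled function
...     td["out"] = td["a"] * 2 + 1
...     return td
>>> td = step(TensorDict({"a": torch.randn(3)}, batch_size=[3]))
>>> print(td["out"].shape)
torch.Size([3])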

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

On the first iteration (i == 0), the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in order).
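
For instance, filling the tensordict in a random order works just as well; here is a short sketch reusing foo() from above:

>>> N = 10
>>> tensordict = TensorDict({}, batch_size=[N])
>>> for i in torch.randperm(N).tolist():
...     tensordict[i] = foo()  # the first assignment triggers the preallocation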

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict must be indexable under the parent tensordict, i.e. its batch size must match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
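
Nested values can be written the same way, using a tuple of keys (the second key name below is arbitrary):

>>> tensordict["key 2", "other sub-key"] = torch.zeros(3, 4, 5)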

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
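
As a short sketch (assuming, as the statement above implies, that stacking and indexing preserve the tensorclass type), the instances above can be manipulated just like tensordicts:

>>> stacked = torch.stack([data, data], 0)
>>> assert stacked.batch_size == torch.Size([2, 100])
>>> assert isinstance(stacked[0], MyData)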

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later and PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may still be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
