
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the leading dimensions of each of the tensors must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and more.
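A minimal sketch of a few of these operations (the variable names are illustrative; the last line assumes a GPU is available):

>>> td = TensorDict({"a": torch.randn(3, 4)}, batch_size=[3, 4])
>>> td_clone = td.clone()                   # deep copy of structure and tensors
>>> td_shared = td.clone().share_memory_()  # place the tensors in shared memory
>>> chunks = td.split(2, dim=0)             # tuple of tensordicts with batch sizes [2, 4] and [1, 4]
>>> rows = td.unbind(0)                     # tuple of 3 tensordicts with batch size [4]
>>> td_expanded = td.expand(2, 3, 4)        # broadcast to batch size [2, 3, 4]
>>> if torch.cuda.is_available():
...     td_cuda = td.to("cuda:0")           # send every tensor to the GPU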

If a functionality is missing, it is easy to call it using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
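As a small sketch of the difference between the two, apply() returns a new tensordict while apply_() updates the stored tensors in place:

td = TensorDict({"a": torch.zeros(3), "b": torch.zeros(3)}, batch_size=[3])
td_double = td.apply(lambda tensor: tensor.to(torch.float64))  # new tensordict, original untouched
assert td["a"].dtype is torch.float32
td.apply_(lambda tensor: tensor.add_(1))  # in-place update of every stored tensor
assert (td["a"] == 1).all()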

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
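A minimal sketch of the memory-mapped workflow (the directory path is illustrative and assumes a scratch space visible to all nodes):

>>> data = TensorDict({
...     "images": torch.zeros(1000, 3, 64, 64),
...     "labels": torch.zeros(1000, dtype=torch.int64),
... }, batch_size=[1000])
>>> data.memmap_(prefix="/shared/scratch/dataset")  # tensors are now backed by files on disk
>>> assert data.is_memmap()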

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
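As an illustration, here is a minimal sketch of a tensordict module (the key names "x", "hidden" and "out" are arbitrary): it reads its inputs from and writes its outputs to a tensordict, and can be passed to torch.compile like any other nn.Module:

>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> module = TensorDictSequential(
...     TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
...     TensorDictModule(nn.ReLU(), in_keys=["hidden"], out_keys=["out"]),
... )
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = module(td)                   # reads "x", writes "hidden" and "out"
>>> compiled = torch.compile(module)  # works on tensordict modules as on any nn.Module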

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

At the first iteration (i == 0), the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in place. Note that this also works with a shuffled series of indices (pre-allocation does not require going through the tensordict in order), as shown below.
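A minimal sketch of the shuffled variant, reusing the foo() defined above:

N = 10
tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():  # visit the indices in random order
    tensordict[i] = foo()             # the first write preallocates, later writes are in-place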

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size must match the parent batch size in its leading dimensions (but can be longer).

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict artifacts such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
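For instance, here is a minimal sketch of one tensorclass nested inside another (the Sample class and its field names are illustrative), reusing the MyData instance created above:

>>> @tensorclass
... class Sample:
...     data: MyData          # a tensorclass (or TensorDict) can itself be a field
...     weight: torch.Tensor
...
>>> sample = Sample(data=data, weight=torch.ones(100), batch_size=[100])
>>> assert sample[:10].data.image.shape == torch.Size([10, 3, 64, 64])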

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
