TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, device casting and point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
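As a sketch of what such a model and loss module could look like (the key names "image", "logits" and "label" are illustrative, not prescribed by the library):

from torch import nn

class Model(nn.Module):
    # the model reads its input from the tensordict and writes its output back,
    # so the training loop never needs to know which keys a given task uses
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(64, 10)

    def forward(self, tensordict):
        tensordict["logits"] = self.net(tensordict["image"])
        return tensordict

def loss_module(tensordict):
    # the loss only needs to know which keys to read
    return nn.functional.cross_entropy(tensordict["logits"], tensordict["label"])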

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must be compliant. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).batch_size)
torch.Size([12])
>>> print(tensordict.reshape(-1).batch_size)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in place (or not), split them, unbind them, expand them, and so on.
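A few representative calls, continuing with the tensordict of batch size [3, 4] built above:

>>> tensordict_cpu = tensordict.to("cpu")        # device casting
>>> tensordict_clone = tensordict.clone()        # deep copy
>>> tensordict_exp = tensordict.expand(2, 3, 4)  # broadcast to a new leading dim
>>> assert tensordict_exp.shape == torch.Size([2, 3, 4])
>>> chunks = tensordict.split(2, dim=0)          # tuple of tensordicts
>>> rows = tensordict.unbind(1)                  # tuple of tensordicts along dim 1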

If a functionality is missing, it is easy to implement it tensor-wise using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
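apply() returns a new tensordict built from the function's outputs; apply_() is its in-place counterpart, writing the results back into the original tensordict:

tensordict.apply_(lambda tensor: tensor + 1)  # increments every entry in place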

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers (assuming a torch.distributed process group is initialized)
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
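A sketch of that workflow, assuming both workers can see a shared directory /shared/scratch (memmap_ writes the contents as memory-mapped tensors on disk, and load_memmap maps them back):

>>> # on the writing node
>>> data = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, [3])
>>> data.memmap_(prefix="/shared/scratch/data")
>>> # on the reading node, over the same filesystem
>>> data2 = TensorDict.load_memmap("/shared/scratch/data")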

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
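As an illustration of such a tensordict module, tensordict.nn.TensorDictModule wraps a regular nn.Module so that it reads its inputs from, and writes its outputs to, a tensordict:

>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = module(td)
>>> assert td["y"].shape == torch.Size([10, 4])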

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

On the first assignment (here, when i == 0), the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in place. Note that this also works with a shuffled series of indices: pre-allocation does not require you to go through the tensordict in an ordered fashion, as the sketch below shows.
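A sketch of the shuffled variant:

tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():  # visit the indices in random order
    tensordict[i] = foo()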

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict must be indexable under the parent tensordict, i.e. its batch size must match (but can be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Nested values can be accessed with a single tuple index:

>>> sub_value = tensordict["key 2", "sub-key"]
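The flat representation can be reverted with unflatten_keys, and nested values can be written with the same tuple-style index:

>>> assert (tensordict_flatten.unflatten_keys(separator=".") == tensordict).all()
>>> tensordict["key 2", "sub-key"] = torch.zeros(3, 4, 5, 6)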

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
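For instance, a tensorclass field can itself be a tensorclass or a TensorDict. A minimal sketch, reusing the MyData instance from above (the Sample class is illustrative):

>>> @tensorclass
... class Sample:
...     data: MyData
...     weight: torch.Tensor
...
>>> sample = Sample(data=data, weight=torch.ones(100), batch_size=[100])
>>> assert sample[:10].data.image.shape == torch.Size([10, 3, 64, 64])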

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, and PyTorch 1.12 and later.

To enjoy the latest features, one can instead install the nightly build:

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
