TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it will be sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating individual tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).batch_size)
torch.Size([12])
>>> print(tensordict.reshape(-1).batch_size)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
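A minimal sketch of a few of these operations, reusing the tensordict defined above (the device move is shown for CPU so that it runs everywhere):

>>> td_cpu = tensordict.to("cpu")           # move (or keep) every tensor on CPU
>>> td_clone = tensordict.clone()           # deep copy of the structure and its tensors
>>> td_shared = td_clone.share_memory_()    # place the clone in shared memory
>>> chunks = tensordict.split(2, dim=1)     # tuple of tensordicts of batch size [3, 2]
>>> rows = tensordict.unbind(0)             # tuple of 3 tensordicts of batch size [4]
>>> expanded = tensordict.expand(2, 3, 4)   # prepend a batch dimension of size 2
>>> assert expanded.batch_size == torch.Size([2, 3, 4])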

If a functionality is missing, it can easily be implemented by calling apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
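The in-place variant apply_() follows the same pattern; a minimal sketch:

tensordict.apply_(lambda tensor: tensor.zero_())  # zeroes every tensor in place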

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read large amounts of data.
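For instance, a minimal sketch of memory-mapping a tensordict to a shared directory (the path below is hypothetical; any location visible to all nodes would do):

>>> data = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, batch_size=[3])
>>> data.memmap_(prefix="/shared/scratch/td")  # tensors are now backed by files on disk
>>> assert data.is_memmap()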

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights for model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict of the same structure (shown here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

On the first iteration, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
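For instance, a minimal sketch of the same loop visiting the indices in a shuffled order (whichever index is written first triggers the pre-allocation):

import torch

N = 10
tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    tensordict[i] = foo()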

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict be indexable under the parent tensordict, i.e. its batch size must match (but may be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
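The reverse operation, unflatten_keys(), restores the nested structure; a minimal sketch:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested["key 2", "sub-key"] == tensordict["key 2", "sub-key"]).all()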

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack), or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
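For instance, a minimal sketch of nesting, where Batch is a hypothetical tensorclass that reuses the MyData instance built above:

>>> @tensorclass
... class Batch:
...     data: MyData          # a tensorclass field can itself be a tensorclass
...     weight: torch.Tensor
...
>>> batch = Batch(data=data, weight=torch.ones(100), batch_size=[100])
>>> print(batch[:5].batch_size)
torch.Size([5])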

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, and PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
