TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device, and point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")
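
Conversely, an entry whose leading dimensions do not match the batch size is rejected. A minimal sketch of that check (the exact error type and message may vary across versions):

>>> try:
...     tensordict["key 4"] = torch.randn(2, 2)
... except RuntimeError:
...     print("leading dims must match batch_size=[3, 4]")
leading dims must match batch_size=[3, 4]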

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimensions:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating individual tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them (in-place or not), split them, unbind them, expand them, etc.
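
A brief sketch of a few of these operations (shapes chosen for illustration):

>>> tensordict = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> tensordict_cpu = tensordict.to("cpu")                   # device-to-device copy
>>> tensordict_shared = tensordict.clone().share_memory_()  # place in shared memory
>>> chunks = tensordict.split(1, dim=0)                     # 3 tensordicts of shape [1, 4]
>>> rows = tensordict.unbind(0)                             # 3 tensordicts of shape [4]
>>> expanded = tensordict.unsqueeze(0).expand(2, 3, 4)
>>> assert expanded.shape == torch.Size([2, 3, 4])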

If a functionality is missing, it is easy to implement it with apply() (which returns a new tensordict) or apply_() (which modifies the tensordict in-place):

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())  # fills every entry with U(0, 1) samples

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)
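
These calls assume an initialized process group, exactly as with torch.distributed. A minimal end-to-end sketch (the "gloo" backend and the rendezvous address are illustrative choices, not requirements):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from tensordict import TensorDict

def worker(rank):
    dist.init_process_group(
        "gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
    )
    if rank == 1:
        data = TensorDict({"a": torch.ones(()), ("b", "c"): torch.ones(())}, [])
        data.send(dst=0)  # blocking send of every leaf tensor
    else:
        # the receiver must hold a tensordict with the same structure
        data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.zeros(())}, [])
        data.recv(src=1)  # filled in-place
        assert (data == 1).all()

if __name__ == "__main__":
    mp.spawn(worker, nprocs=2)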

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
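
A short sketch of that pattern, assuming a shared filesystem mounted at the illustrative path /shared/scratch:

>>> data = TensorDict({"a": torch.zeros(1000, 1000)}, batch_size=[1000])
>>> data.memmap_(prefix="/shared/scratch/td")  # every tensor becomes a memory-mapped file
>>> # other nodes can now read the same data without an extra copy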

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

On the first assignment (whatever the index), the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in order), as the sketch below shows.
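
A short sketch of the shuffled case (reusing foo() from above):

N = 10
tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():  # visit indices in random order
    tensordict[i] = foo()             # the first write triggers the preallocation
assert tensordict.batch_size == torch.Size([N])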

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size must match (but may extend) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
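
The reverse operation rebuilds the hierarchy; a short sketch continuing the example above:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested["key 2", "sub-key"] == tensordict["key 2", "sub-key"]).all()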

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
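
For instance, tensorclass instances support the same batch operations as tensordicts; a short sketch reusing MyData from above:

>>> batch = torch.stack([data, data], dim=0)
>>> assert batch.batch_size == torch.Size([2, 100])
>>> assert batch.image.shape == torch.Size([2, 100, 3, 64, 64])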

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, and PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-incompatible changes may be introduced, but they should come with a warning. These should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
