
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be used across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must be compliant. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device; in that case, each tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and so on.
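
A minimal sketch of a few of these operations (method names follow the tensordict API; the shapes noted in the comments hold for this example only):

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> tensordict_cpu = tensordict.to("cpu")   # device casting
>>> tensordict_clone = tensordict.clone()   # deep copy of every entry
>>> chunks = tensordict.split(2, dim=1)     # two tensordicts with batch size [3, 2]
>>> rows = tensordict.unbind(0)             # tuple of 3 tensordicts with batch size [4]
>>> expanded = tensordict.expand(2, 3, 4)   # adds a leading dimension of size 2
>>> assert expanded.batch_size == torch.Size([2, 3, 4])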

If a functionality is missing, it is easy to implement it using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
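
For instance, a tensordict can be dumped to memory-mapped files on the common scratch space with memmap_(). A minimal sketch, assuming the nodes share a filesystem (the prefix path below is hypothetical):

>>> data = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, batch_size=[3])
>>> data.memmap_(prefix="/mnt/shared/scratch/data")  # write each tensor to a memory-mapped file
>>> assert data.is_memmap()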

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

On the first assignment, your empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion), as the sketch below shows.
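
A minimal sketch of the shuffled-index case, reusing the foo() function defined above:

>>> N = 10
>>> tensordict = TensorDict({}, batch_size=[N])
>>> for i in torch.randperm(N).tolist():  # visit the indices in a random order
...     tensordict[i] = foo()             # the first write preallocates, later writes are in-place
>>> assert tensordict["a"].shape == torch.Size([N, 3])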

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) the parent tensordict's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
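
Writing with the same tuple-key syntax creates or updates a nested entry, provided the leading dimensions stay compliant. A short sketch continuing the example above:

>>> tensordict["key 2", "sub-key 2"] = torch.zeros(3, 4, 5, 2)  # new entry in the nested tensordict
>>> assert tensordict["key 2"]["sub-key 2"].shape == torch.Size([3, 4, 5, 2])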

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
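
A rough sketch of what nesting can look like, reusing the MyData class and the data instance from above (the Batch class below is hypothetical and defined only for illustration):

>>> @tensorclass
... class Batch:
...     data: MyData              # a tensorclass (or a TensorDict) can be nested as a field
...     weight: torch.Tensor
...
>>> nested = Batch(data=data, weight=torch.ones(100), batch_size=[100])
>>> assert nested[:10].data.image.shape == torch.Size([10, 3, 64, 64])  # indexing propagates to nested fields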

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
