
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it will be sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
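
For example, here is a minimal sketch of a few of these operations (the device transfer assumes a CUDA GPU is available):

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> tensordict_cuda = tensordict.to("cuda:0")                # device-to-device transfer
>>> tensordict_shared = tensordict.clone().share_memory_()   # clone, then place in shared memory
>>> tensordict_expanded = tensordict.expand(2, 3, 4)         # prepend an expanded dimension
>>> assert tensordict_expanded.shape == torch.Size([2, 3, 4])
>>> chunks = tensordict.split(1, dim=0)                      # split along the first batch dimension
>>> assert len(chunks) == 3
>>> first, *rest = tensordict.unbind(0)                      # unbind along the first batch dimension
>>> assert first.shape == torch.Size([4])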

If a functionality is missing, it is easy to implement it with apply() or its in-place variant apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
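
The in-place variant apply_() follows the same pattern but writes the results back into the tensordict itself; a small sketch:

tensordict.apply_(lambda tensor: tensor.zero_())  # zero every entry in-place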

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
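
As a minimal sketch of that workflow (the scratch-space path below is purely illustrative and assumed to be readable by all nodes):

>>> data = TensorDict({"a": torch.zeros(3), ("b", "c"): torch.ones(3)}, [3])
>>> # write the content as memory-mapped files on the shared scratch space
>>> data.memmap_(prefix="/shared/scratch/my_run")
>>> # on another node, lazily read the same data back
>>> data_read = TensorDict.load_memmap("/shared/scratch/my_run")
>>> assert (data_read["b", "c"] == 1).all()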

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights for model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
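
For instance, a tensordict module can be passed to torch.compile like any other nn.Module; here is a minimal sketch (assuming PyTorch 2.x; TensorDictModule is tensordict's nn.Module wrapper that reads and writes tensordict entries):

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> compiled_module = torch.compile(module)
>>> data = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> data = compiled_module(data)
>>> assert data["y"].shape == torch.Size([10, 4])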

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in the following tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

At the first iteration (i == 0), the empty tensordict is automatically populated with empty tensors of batch size N. From then on, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in order), as sketched below.
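
A sketch of the same preallocation with a shuffled index order (re-using foo() and N from above):

import random

indices = list(range(N))
random.shuffle(indices)

tensordict = TensorDict({}, batch_size=[N])
for i in indices:
    # the first write, whatever its index, triggers preallocation of the full batch
    tensordict[i] = foo()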

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) the parent batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
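
Nested values can also be written with the same tuple-style key, and a flattened tensordict can be nested back with unflatten_keys; a short sketch building on the example above ("sub-key 2" is just an illustrative name):

>>> tensordict["key 2", "sub-key 2"] = torch.ones(3, 4, 5)
>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested["key 2", "sub-key"] == tensordict["key 2", "sub-key"]).all()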

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
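
For instance, a tensorclass field can itself be another tensorclass (or a TensorDict); here is a minimal sketch re-using MyData and data from the example above (Sample and score are illustrative names):

>>> @tensorclass
... class Sample:
...     data: MyData
...     score: torch.Tensor
...
>>> sample = Sample(data=data, score=torch.zeros(100), batch_size=[100])
>>> assert sample[:10].data.image.shape == torch.Size([10, 3, 64, 64])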

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This requires Python 3.7 or later and PyTorch 1.12 or later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a deprecation warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
