
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
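
Concretely, the model in the loop above can be any callable that reads from and writes to a tensordict. As a minimal sketch, assuming tensordict.nn is available and using the illustrative keys "obs" and "hidden", a plain nn.Module can be wrapped with TensorDictModule:

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# wrap a regular module: it reads "obs" from the tensordict and writes "hidden" back into it
module = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["hidden"])
data = TensorDict({"obs": torch.randn(10, 3)}, batch_size=[10])
data = module(data)
assert data["hidden"].shape == torch.Size([10, 4])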

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must be compliant. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])
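
Indexing can also be used to write values in place. As a minimal sketch (the assigned values are purely illustrative), a sub-tensordict can be assigned at a given index, provided its batch size matches the indexed shape:

>>> tensordict[:, 0] = TensorDict({
...     "key 1": torch.zeros(3, 5),
...     "key 2": torch.ones(3, 5, dtype=torch.bool),
... }, batch_size=[3])
>>> assert (tensordict["key 1"][:, 0] == 0).all()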

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
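
A minimal sketch of a few of these operations on the tensordict defined above:

>>> td_cpu = tensordict.to("cpu")        # cast to another device
>>> td_clone = tensordict.clone()        # deep copy
>>> td_clone.update(tensordict)          # in-place update from another tensordict
>>> chunks = tensordict.split(2, dim=1)  # split along a batch dimension
>>> parts = tensordict.unbind(0)         # tuple of tensordicts along dim 0
>>> bigger = tensordict.expand(2, 3, 4)  # expand the batch dimensions
>>> assert bigger.batch_size == torch.Size([2, 3, 4])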

If a functionality is missing, it is easy to implement it using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
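
As a hedged sketch (the prefix path below is hypothetical), a tensordict can be mapped onto such a shared scratch space with memmap_:

>>> data = TensorDict({"a": torch.zeros(10), "b": torch.ones(10, 5)}, batch_size=[10])
>>> # write every tensor to a memory-mapped file under the given directory
>>> data.memmap_(prefix="/shared/scratch/my_run")
>>> assert data.is_memmap()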

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
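
As a minimal sketch of what this enables (assuming PyTorch 2.0 or later for torch.compile; the module below is purely illustrative), a tensordict module can be passed to torch.compile like any other nn.Module:

>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> compiled_module = torch.compile(module)
>>> out = compiled_module(TensorDict({"x": torch.randn(10, 3)}, batch_size=[10]))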

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, your empty tensordict will automatically be populated with empty tensors of batch size N. After that, updates will be written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
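
For instance, this minimal sketch fills the preallocated tensordict in a shuffled order:

N = 10
tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    # the first write allocates the full storage, later ones fill it in-place
    tensordict[i] = foo()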

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) the parent batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
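
The flat and nested forms can be converted back and forth, and nested values can also be written with tuple keys. A short sketch:

>>> assert (tensordict_flatten["key 2.sub-key"] == tensordict["key 2", "sub-key"]).all()
>>> # round-trip back to the nested representation
>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> # write a new nested value with a tuple key
>>> tensordict["key 2", "other sub-key"] = torch.zeros(3, 4, 5, 6)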

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
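
As a minimal sketch (the class and field names are illustrative), a tensorclass can hold a TensorDict or another tensorclass as a field:

>>> from tensordict import TensorDict
>>> @tensorclass
... class Sample:
...     data: MyData
...     metadata: TensorDict
...
>>> sample = Sample(
...     data=MyData(images, mask, label=label, batch_size=[100]),
...     metadata=TensorDict({"step": torch.arange(100)}, batch_size=[100]),
...     batch_size=[100],
... )
>>> assert sample[:10].batch_size == torch.Size([10])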

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a deprecation guarantee. Hopefully, these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
