TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to a device, and point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of each of the tensors. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written into it will be sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating individual tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).batch_size)
torch.Size([12])
>>> print(tensordict.reshape(-1).batch_size)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in place (or not), split them, unbind them, expand them, etc.
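
A minimal sketch of a few of these operations, assuming the tensordict defined above and an available CUDA device:

>>> tensordict_cuda = tensordict.to("cuda:0")       # cast every entry to the device
>>> tensordict_clone = tensordict.clone()           # deep copy of structure and content
>>> chunks = tensordict.split(2, dim=0)             # split along the first batch dimension
>>> sub_tensordicts = tensordict.unbind(1)          # tuple of tensordicts along dim 1
>>> tensordict_expand = tensordict.expand(2, 3, 4)  # broadcast to a larger batch size
>>> tensordict_shared = tensordict.share_memory_()  # place the content in shared memory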

If a functionality is missing, it is easy to apply it entry-wise using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
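
The in-place variant apply_() writes the result back into the same tensordict, for instance:

tensordict.apply_(lambda tensor: tensor.zero_())  # zeroes every entry in place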

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
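
For instance, a minimal sketch of memory-mapping a tensordict, assuming /mnt/shared/data is a scratch directory visible to all nodes:

>>> data = TensorDict({"a": torch.zeros(10)}, batch_size=[10])
>>> data.memmap_(prefix="/mnt/shared/data")  # entries are now backed by memory-mapped files on disk
>>> assert data.is_memmap()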

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
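
As a minimal sketch (module, key names and shapes are illustrative), a TensorDictModule reads its inputs from and writes its outputs to a tensordict, and can be passed to torch.compile:

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> net = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> compiled_net = torch.compile(net)
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = compiled_net(td)
>>> assert "y" in td.keys()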

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

During the first iteration, your empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion), as sketched below.
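
For instance, a minimal sketch reusing foo() from above, with N set arbitrarily for illustration:

import random

N = 10
tensordict = TensorDict({}, batch_size=[N])
indices = list(range(N))
random.shuffle(indices)
for i in indices:
    tensordict[i] = foo()  # the first write preallocates, later writes are in-place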

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e., its batch size should match (but could be longer than) the parent batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
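
As a minimal sketch (the class and field names are illustrative), a tensorclass field can itself be a tensorclass or a TensorDict:

>>> @tensorclass
... class MyDataNested:
...     image: torch.Tensor
...     extra: MyData
...
>>> nested = MyDataNested(
...     image=torch.randn(100, 3, 64, 64),
...     extra=data,
...     batch_size=[100])
>>> assert nested.extra.label.shape == torch.Size([100])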

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
