
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make codebases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of each of the tensors. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case each tensor written to it will be sent there:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating individual tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
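A brief sketch of a few of these methods (the .to call assumes a CUDA device is available):

>>> tensordict_clone = tensordict.clone()                 # deep copy of structure and tensors
>>> tensordict_cuda = tensordict.to("cuda:0")             # send every tensor to the GPU
>>> tensordict_shared = tensordict_clone.share_memory_()  # place tensors in shared memory, in-place
>>> chunks = tensordict_clone.split(2, dim=0)             # split along the first batch dimension
>>> rows = tensordict_clone.unbind(0)                     # a tuple of tensordicts, one per row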

If a functionality is missing, it is easy to implement it using apply() or apply_():

# draw a new uniform sample for every tensor in the tensordict
tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
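The in-place counterpart, apply_, modifies the tensors of the tensordict directly. A minimal sketch:

# zero out every tensor in place, without allocating a new tensordict
tensordict.apply_(lambda tensor: tensor.zero_())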

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read huge amounts of data.
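A minimal sketch of what this can look like, assuming a recent tensordict version and a hypothetical shared scratch directory /shared/scratch/td:

>>> data = TensorDict({"a": torch.zeros(3), "b": torch.ones(3, 4)}, batch_size=[3])
>>> # write the tensordict to memory-mapped files on the shared file system
>>> data.memmap_(prefix="/shared/scratch/td")
>>> # other nodes can then load it lazily, without copying the data over the network
>>> data_remote = TensorDict.load_memmap("/shared/scratch/td")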

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

At the first iteration (i == 0), the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in order), as shown in the sketch below.
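For instance, the following sketch fills the same preallocated tensordict in a shuffled order (with N = 10):

import torch

N = 10
tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():  # indices visited in random order
    tensordict[i] = foo()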

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict be indexable under the parent tensordict, i.e. its batch size must match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
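The reverse operation, unflatten_keys, restores the nested structure, so the two representations can be round-tripped:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested == tensordict).all()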

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
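For instance, a tensorclass can hold another tensorclass (or a TensorDict) as a field. A minimal sketch building on MyData above (the Batch class is hypothetical):

>>> @tensorclass
... class Batch:
...     data: MyData
...     noise: torch.Tensor
...
>>> batch = Batch(data=data, noise=torch.randn(100, 3), batch_size=[100])
>>> assert batch[:10].data.image.shape == torch.Size([10, 3, 64, 64])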

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, and PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a deprecation warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
