TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device, and point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
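
As a sketch of how a model can read and write tensordicts, tensordict.nn provides TensorDictModule, which wraps a regular nn.Module so that it reads its inputs from, and writes its outputs to, a tensordict (the "obs" and "action" key names below are illustrative, not prescribed):

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# wrap a plain linear layer so it reads "obs" and writes "action"
module = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["action"])
data = TensorDict({"obs": torch.randn(10, 3)}, batch_size=[10])
data = module(data)
assert data["action"].shape == torch.Size([10, 4])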

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is moved to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and more.
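
A quick sketch of a few of these operations (shapes chosen purely for illustration):

>>> td = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> td_clone = td.clone()        # deep copy of the structure and its tensors
>>> chunks = td.split(1, dim=0)  # tensordicts of batch_size [1, 4]
>>> rows = td.unbind(0)          # tensordicts of batch_size [4]
>>> big = td.expand(2, 3, 4)     # broadcast to a new leading dimension
>>> assert big.batch_size == torch.Size([2, 3, 4])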

If a functionality is missing, it is easy to implement it with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
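
apply_() is the in-place counterpart: it writes the result of the function back into the same tensordict instead of returning a new one. A minimal sketch:

tensordict.apply_(lambda tensor: tensor.zero_())  # zeroes every entry in-place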

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read huge amounts of data.
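
A minimal sketch of that workflow (the path below is a hypothetical shared mount): memmap_() moves every tensor of the tensordict into memory-mapped files on disk, which other processes can then read without extra copies:

>>> data = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
>>> data.memmap_("/mnt/shared/scratch")  # hypothetical shared directory
>>> assert data.is_memmap()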

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights for model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the results of the repeated calls:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

On the first iteration (i == 0), the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
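
For instance, the following sketch (reusing foo() and N from above) fills the preallocated tensordict in a random order:

tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    tensordict[i] = foo()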

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")
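
The reverse operation, unflatten_keys, restores the nested structure, so flattening and unflattening with the same separator round-trips:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested == tensordict).all()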

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
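
Tuple keys also work for writing, which makes it easy to build nested structures incrementally (the key name below is illustrative):

>>> tensordict["key 2", "other-sub-key"] = torch.zeros(3, 4, 5)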

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
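
As a minimal sketch of nesting (reusing MyData and data from above; the class and field names are illustrative):

>>> @tensorclass
... class Annotated:
...     sample: MyData
...     score: torch.Tensor
...
>>> annotated = Annotated(sample=data, score=torch.rand(100), batch_size=[100])
>>> assert annotated.sample.image.shape == torch.Size([100, 3, 64, 64])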

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and upward, as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
