TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be used across classification and segmentation tasks, among many others.
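
As a rough illustration of what "the model reads and writes tensordicts" can look like, here is a minimal sketch using tensordict.nn.TensorDictModule; the key names ("observation" and "prediction") are invented for the example:

# a minimal sketch: wrap a regular module so that it reads and writes tensordict keys
import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

model = TensorDictModule(nn.Linear(3, 4), in_keys=["observation"], out_keys=["prediction"])
tensordict = TensorDict({"observation": torch.randn(10, 3)}, batch_size=[10])
tensordict = model(tensordict)  # writes "prediction" into the tensordict
assert tensordict["prediction"].shape == torch.Size([10, 4])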

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must be compliant. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device; any tensor written to it is then sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in place (or not), split them, unbind them, expand them, and so on.
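
For instance, a few of these operations might look like the following (a brief sketch reusing the tensordict defined above):

>>> tensordict_cpu = tensordict.to("cpu")                         # device casting
>>> tensordict_shared = tensordict.clone().share_memory_()       # clone + shared memory
>>> tensordict_expand = tensordict.unsqueeze(0).expand(2, 3, 4)  # expand along a new leading dim
>>> chunks = tensordict.split(2, dim=1)                          # split along the second batch dimension
>>> items = tensordict.unbind(0)                                 # a tuple of 3 tensordicts of batch_size [4]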

If a functionality is missing, it is easy to call it using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
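
As a rough sketch (the shared path below is hypothetical), one process can dump a tensordict to the shared scratch space and another can map it back without copying the data over the network:

>>> # on the writer node
>>> data.memmap_(prefix="/shared/scratch/buffer")
>>> # on the reader node
>>> data = TensorDict.load_memmap("/shared/scratch/buffer")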

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
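
As a minimal sketch of what this can look like in practice (assuming a tensordict.nn.TensorDictModule wrapper and hypothetical keys), a tensordict-aware module can be handed to torch.compile like any other nn.Module:

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(torch.nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> compiled_module = torch.compile(module)
>>> data = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> data = compiled_module(data)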

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict like the following (here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i==0, your empty tensordict will automatically be populated with empty tensors of batch-size N. After that, updates will be written in-place. Note that this would also work with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
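
For example, filling the preallocated tensordict in a shuffled order (a short sketch reusing foo() from above) works just as well:

import torch

N = 10
tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():  # indices visited in random order
    tensordict[i] = foo()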

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) the parent batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
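
Writing nested values works the same way, and the flat representation can be turned back into a nested one (the extra "sub-key 2" entry below is just for illustration):

>>> tensordict["key 2", "sub-key 2"] = torch.zeros(3, 4, 5)  # assignment with a tuple key
>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")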

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
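
For instance, a tensorclass field can itself be another tensorclass. A small sketch building on MyData above (the Batch class and its field names are invented):

>>> @tensorclass
... class Batch:
...     data: MyData          # a nested tensorclass field
...     step: torch.Tensor
...
>>> batch = Batch(data=data, step=torch.arange(100), batch_size=[100])
>>> print(batch[:10].data.image.shape)
torch.Size([10, 3, 64, 64])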

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and above, as well as PyTorch 1.12 and above.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.


