TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
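
One convenient way to write such a tensordict-compatible model is with tensordict.nn.TensorDictModule, which wraps a regular nn.Module and specifies which entries it reads and writes. A minimal sketch (the key names "obs" and "pred" are placeholders chosen for this example):

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# wrap a plain nn.Module so that it reads "obs" from the tensordict and writes "pred" back
model = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["pred"])
tensordict = TensorDict({"obs": torch.randn(10, 3)}, batch_size=[10])
tensordict = model(tensordict)  # the output is stored under "pred"
assert tensordict["pred"].shape == torch.Size([10, 4])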

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must be compliant. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device: any tensor written to it will then be sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).batch_size)
torch.Size([12])
>>> print(tensordict.reshape(-1).batch_size)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in place (or not), split them, unbind them, expand them, and so on.
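
As a short sketch of some of these operations (shapes follow the example above):

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
... }, batch_size=[3, 4])
>>> td_cpu = tensordict.to("cpu")        # device casting (a no-op here)
>>> td_clone = tensordict.clone()        # deep copy of the tensordict and its tensors
>>> chunks = tensordict.split(2, dim=1)  # tuple of tensordicts with batch sizes [3, 2] and [3, 2]
>>> rows = tensordict.unbind(0)          # tuple of 3 tensordicts with batch size [4]
>>> big = tensordict.expand(2, 3, 4)     # prepend an expanded dimension
>>> assert big.shape == torch.Size([2, 3, 4])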

If a functionality is missing, it can often be implemented with apply() or apply_():

# apply() returns a new tensordict built from the tensors returned by the lambda
tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
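
As a hedged sketch (the scratch path below is only a placeholder), a tensordict can be written to memory-mapped files that every node can read:

>>> data = TensorDict({"a": torch.zeros(3), ("b", "c"): torch.ones(3)}, [3])
>>> data.memmap_(prefix="/path/to/shared/scratch")  # stores each tensor as a memory-mapped file
>>> assert data.is_memmap()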

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to iterate over the tensordict in order).
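
For instance, a minimal sketch filling the same preallocated tensordict in random order:

N = 10
tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    tensordict[i] = foo()  # the first assignment preallocates the full storage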

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict must be indexable under the parent tensordict, i.e. its batch size must contain the parent's batch size as a prefix (it may have additional dimensions).

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
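
Writing through a nested key works the same way (the new sub-key name below is arbitrary):

>>> tensordict["key 2", "other-sub-key"] = torch.zeros(3, 4, 5, dtype=torch.bool)
>>> assert tensordict["key 2"]["other-sub-key"].shape == torch.Size([3, 4, 5])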

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permute), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
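
As a hedged sketch of nesting (the class and field names are made up for this example), a tensorclass can hold another tensorclass as a field, and indexing propagates to it:

>>> @tensorclass
... class Batch:
...     data: MyData
...     step: torch.Tensor
...
>>> batch = Batch(data=data, step=torch.arange(100), batch_size=[100])
>>> assert batch[:10].data.image.shape == torch.Size([10, 3, 64, 64])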

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
