
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make codebases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of each of the tensors. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it will be sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
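For instance, here is a minimal sketch of a few of these operations (assuming a CUDA device is available for the device cast):

>>> td = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> td_cuda = td.to("cuda:0")       # cast every tensor to cuda:0
>>> td_clone = td.clone()           # deep copy of the tensordict
>>> td_chunks = td.split(1, dim=0)  # 3 tensordicts of batch_size [1, 4]
>>> td_rows = td.unbind(0)          # 3 tensordicts of batch_size [4]
>>> td_big = td.expand(2, 3, 4)     # batch_size [2, 3, 4], without copying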

If a functionality is missing, it is easy to apply it to every entry with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
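As a minimal sketch (the "/mnt/shared" prefix is an assumption standing in for the common scratch space), one can write a tensordict to disk as memory-mapped tensors with memmap_():

>>> data = TensorDict({"a": torch.zeros(1000, 128)}, batch_size=[1000])
>>> data.memmap_(prefix="/mnt/shared/data")  # each tensor becomes a memory-mapped file under this path

Once memory-mapped, other processes or nodes mounting the same path can read the data without an extra copy.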

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
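As an illustrative sketch of the latter (combining a TensorDictModule with torch.compile is shown here as an assumption of how the two fit together):

>>> from tensordict.nn import TensorDictModule
>>> net = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> net_c = torch.compile(net)  # assumption: the wrapped module is compiled directly
>>> td = net_c(TensorDict({"x": torch.randn(10, 3)}, batch_size=[10]))
>>> print(td["y"].shape)
torch.Size([10, 4])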

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i==0, your empty tensordict will automatically be populated with empty tensors of batch size N. After that, updates will be written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion), as sketched below.
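A minimal sketch of pre-allocation with shuffled indices, reusing foo() from above:

tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():  # visit the indices in random order
    tensordict[i] = foo()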

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")
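The reverse operation, unflatten_keys, restores the hierarchy; a short sketch of the round trip:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested == tensordict).all()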

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
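For instance, a minimal sketch of a nested tensorclass (the Batch class and its fields are illustrative, reusing MyData and data from above):

>>> @tensorclass
... class Batch:
...     data: MyData          # a tensorclass can hold another tensorclass
...     weight: torch.Tensor
...
>>> batch = Batch(data=data, weight=torch.rand(100), batch_size=[100])
>>> print(batch[:5].batch_size)
torch.Size([5])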

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and later as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
