

TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
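
As a minimal sketch of what such a model could look like (the wrapper and the key names "observation" and "action" below are illustrative, not a prescribed interface), any nn.Module can be wrapped in a TensorDictModule so that it reads its inputs from and writes its outputs to the tensordict:

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> # wrap a plain linear layer: it reads "observation" and writes "action"
>>> model = TensorDictModule(nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"])
>>> data = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
>>> data = model(data)
>>> assert "action" in data.keys()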

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of every tensor in the tensordict. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")
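
Conversely, writing an entry whose leading dimensions do not match the batch size is rejected. A minimal sketch, assuming the mismatch raises a RuntimeError as in recent versions:

>>> try:
...     tensordict["key 4"] = torch.randn(2, 4, 5)  # leading dims (2, 4) != batch_size (3, 4)
... except RuntimeError as err:
...     print("incompatible shape:", err)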

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])
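
Permutations of the batch dimensions work the same way (a short sketch):

>>> print(tensordict.permute(1, 0).shape)
torch.Size([4, 3])
>>> print(tensordict.permute(1, 0)["key 1"].shape)
torch.Size([4, 3, 5])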

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and more.
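
A few of these operations, as a quick sketch (the CUDA transfer of course requires a GPU):

>>> td_clone = tensordict.clone()                    # deep copy of the structure and tensors
>>> td_shared = tensordict.clone().share_memory_()   # place every tensor in shared memory
>>> chunks = tensordict.split(2, dim=0)              # tensordicts with batch sizes [2, 4] and [1, 4]
>>> rows = tensordict.unbind(0)                      # 3 tensordicts with batch size [4]
>>> # tensordict.to("cuda:0") would move every tensor to the GPU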

If a functionality is missing, it is easy to implement it with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
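
For instance, a tensordict can be dumped as a collection of memory-mapped tensors in a folder that every node can read. A minimal sketch, assuming /path/to/shared/scratch points to the common scratch space:

>>> data = TensorDict({"a": torch.zeros(1000, 3), "b": torch.ones(1000, 4)}, batch_size=[1000])
>>> data.memmap_(prefix="/path/to/shared/scratch/data")  # writes every tensor as a memory-mapped file
>>> assert data.is_memmap()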

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
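
A minimal sketch of a tensordict module passed through torch.compile (assuming PyTorch 2.x, where torch.compile is available):

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> compiled_module = torch.compile(module)
>>> data = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> data = compiled_module(data)
>>> assert data["y"].shape == torch.Size([10, 4])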

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (when N = 10)

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

At the first iteration (i == 0), the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in order).
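
For instance, the loop below (a short sketch) fills the same preallocated tensordict in a random order:

N = 10
tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():  # shuffled indices
    tensordict[i] = foo()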

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict must be indexable under the parent tensordict, i.e. its batch size must match (but could be longer than) the parent batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
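
Setting a nested value works the same way, and intermediate tensordicts are created on the fly when needed (a short sketch; the key names are illustrative):

>>> tensordict["key 2", "other-sub-key"] = torch.zeros(3, 4, 5)
>>> tensordict["key 3", "new-sub-key"] = torch.ones(3, 4)  # creates the nested "key 3" tensordict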

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features, such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
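
A short sketch of such nesting, where one tensorclass field holds another tensorclass instance (the class and field names below are illustrative):

>>> @tensorclass
... class Sample:
...     data: MyData
...     metadata: torch.Tensor
...
>>> sample = Sample(data=data, metadata=torch.zeros(100), batch_size=[100])
>>> print(sample[:10].data.image.shape)
torch.Size([10, 3, 64, 64])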

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may still be introduced, but they should come with a deprecation warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.



