
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the leading dimensions of each tensor must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimensions:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and so on.
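A short sketch of a few of these operations (the exact signatures may vary slightly across versions):

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> # cast every tensor to another device (a no-op here, everything is on CPU)
>>> tensordict_cpu = tensordict.to("cpu")
>>> # clone the tensordict, then update the clone in-place from the original
>>> cloned = tensordict.clone()
>>> cloned.update_(tensordict)
>>> # split along the first batch dimension
>>> first, second = tensordict.split([1, 2], dim=0)
>>> assert first.shape == torch.Size([1, 4])
>>> # expand to a larger batch size (a new leading dimension is broadcast)
>>> expanded = tensordict.expand(2, 3, 4)
>>> assert expanded.shape == torch.Size([2, 3, 4])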

If a functionality is missing, it is easy to implement it with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
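As a rough sketch (the shared path below is only a placeholder, and the memmap API may differ slightly across versions), a tensordict can be dumped to memory-mapped files on the scratch space and lazily re-loaded by any process that sees the same path:

>>> data = TensorDict({"a": torch.zeros(3), ("b", "c"): torch.ones(3)}, [3])
>>> # write every tensor to memory-mapped files under the shared prefix
>>> data.memmap_(prefix="/tmp/shared_scratch/data")
>>> # another process with access to the same path can load it lazily
>>> loaded = TensorDict.load_memmap("/tmp/shared_scratch/data")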

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
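As an illustration of a tensordict module, the sketch below wraps a plain nn.Linear so that it reads its input from and writes its output to a tensordict (the key names "x" and "y" are arbitrary):

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> import torch
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = module(td)  # reads "x", writes "y"
>>> assert td["y"].shape == torch.Size([10, 4])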

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i==0, your empty tensordict will automatically be populated with empty tensors of batch-size N. After that, updates will be written in-place. Note that this would also work with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
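For instance, a small sketch of filling the preallocated tensordict in shuffled order (reusing N and foo() from above):

tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    tensordict[i] = foo()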

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size must match (but can be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")
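The reverse operation, unflatten_keys, rebuilds the nested structure from the flat one:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested["key 2", "sub-key"] == tensordict["key 2", "sub-key"]).all()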

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
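For example, a tensorclass field can itself be another tensorclass or a TensorDict; the sketch below (field names are made up) nests the MyData class defined above:

>>> @tensorclass
... class Sample:
...     data: MyData
...     stats: TensorDict
...
>>> sample = Sample(
...     data=MyData(images, mask, label=label, batch_size=[100]),
...     stats=TensorDict({"mean": images.mean(dim=(1, 2, 3))}, batch_size=[100]),
...     batch_size=[100],
... )
>>> # indexing propagates to every nested field
>>> assert sample[:10].data.image.shape == torch.Size([10, 3, 64, 64])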

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and above, and with PyTorch 1.12 and above.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that bc-breaking changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
