TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can be easily used across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must be compatible. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case each tensor written to it will be sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating individual tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).batch_size)
torch.Size([12])
>>> print(tensordict.reshape(-1).batch_size)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
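
For instance, a short sketch of a few of these operations on the tensordict defined above (method names follow the tensordict API, but exact signatures may vary across versions):

>>> tensordict_cpu = tensordict.to("cpu")                    # device-to-device copy
>>> tensordict_shared = tensordict.clone().share_memory_()  # clone, then place in shared memory
>>> chunks = tensordict.split(2, dim=1)                      # split along the second batch dimension
>>> rows = tensordict.unbind(0)                              # unbind along the first batch dimension
>>> expanded = tensordict.unsqueeze(0).expand(2, 3, 4)       # expand to a larger batch size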

If a functionality is missing, it is easy to implement it using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
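
For instance, a minimal sketch, assuming a shared directory such as /mnt/shared is mounted on every node (the path is a placeholder):

>>> # on the writer node: dump the tensordict to memory-mapped files on the shared file system
>>> data.memmap_(prefix="/mnt/shared/data")
>>> # on the reader nodes: map the same files without copying the data over the network
>>> data_remote = TensorDict.load_memmap("/mnt/shared/data")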

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
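
As a rough sketch of what a tensordict module looks like (TensorDictModule lives in tensordict.nn; the torch.compile call is optional and assumes a recent PyTorch version):

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = module(td)                          # reads "x", writes "y"
>>> assert td["y"].shape == torch.Size([10, 4])
>>> compiled_module = torch.compile(module)  # the module can also be compiled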

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

At the first iteration (here i == 0), the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in order), as the sketch below shows.
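
A short sketch of the same preallocation with shuffled indices:

tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N):
    # the first assigned index triggers the preallocation, whichever index it is
    tensordict[int(i)] = foo()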

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) that of the parent.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
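
A minimal sketch of a nested tensorclass (the wrapper class and its fields are hypothetical):

>>> @tensorclass
... class MyDataWrapper:
...     data: MyData
...     flag: torch.Tensor
...
>>> wrapped = MyDataWrapper(data=data, flag=torch.zeros(100), batch_size=[100])
>>> print(wrapped.data.image.shape)
torch.Size([100, 3, 64, 64])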

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with advance warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
