TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must be compliant. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it will be sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating multiple tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
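
For example, a few of these operations can be combined as follows (a minimal sketch; the device transfer is guarded since it assumes a CUDA device is available):

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> cloned = tensordict.clone()               # deep copy of every entry
>>> chunks = tensordict.split([1, 3], dim=1)  # split along the second batch dimension
>>> assert chunks[0].batch_size == torch.Size([3, 1])
>>> items = tensordict.unbind(0)              # tuple of 3 tensordicts with batch_size [4]
>>> expanded = tensordict.expand(2, 3, 4)     # prepend an expanded leading dimension
>>> assert expanded.batch_size == torch.Size([2, 3, 4])
>>> if torch.cuda.is_available():
...     tensordict_cuda = tensordict.to("cuda:0")  # send every tensor to the GPU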

If a functionality is missing, it can easily be implemented with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)
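
A slightly more complete sketch of the blocking variants, assuming a "gloo" process group has already been initialized across two workers (the init_process_group call below is standard torch.distributed boilerplate, not a tensordict API):

>>> import torch.distributed as dist
>>> # dist.init_process_group("gloo", rank=rank, world_size=2) was called on every worker
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> if dist.get_rank() == 1:
...     data.send(dst=0)   # blocking: returns once worker 0 has received the payload
... else:
...     out = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.zeros(())}, [])
...     out.recv(src=1)    # fills `out` in place with the tensors sent by worker 1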

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
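
For instance, a tensordict can be mapped onto such a shared directory with memmap_(), after which its entries are backed by memory-mapped files that other processes can read directly from disk (a minimal sketch; /shared/scratch/run_0 is a hypothetical path on the common scratch space):

>>> data = TensorDict({
...     "observations": torch.randn(1000, 3, 64, 64),
...     "rewards": torch.zeros(1000, 1),
... }, batch_size=[1000])
>>> data.memmap_(prefix="/shared/scratch/run_0")  # replace each tensor with a memory-mapped copy on disk
>>> assert data.is_memmap()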

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

At the first iteration (i == 0 above), the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in place. Note that this also works with a shuffled series of indices, as shown below (pre-allocation does not require you to go through the tensordict in an ordered fashion).
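
For instance, the pre-allocation loop above works just as well over a random permutation of the indices (a minimal sketch, with N = 10):

>>> N = 10
>>> tensordict = TensorDict({}, batch_size=[N])
>>> for i in torch.randperm(N):
...     tensordict[int(i)] = foo()  # the first write, whatever its index, preallocates the storage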

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict be indexable under the parent tensordict, i.e. its batch size must match (but could be longer than) that of the parent.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")
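
The inverse operation, unflatten_keys(), rebuilds the nested structure from the flat one (a minimal sketch reusing the tensordict above):

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested["key 2", "sub-key"] == tensordict["key 2", "sub-key"]).all()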

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
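
Nested values can be written the same way, using a tuple of keys as the index (a minimal sketch; "sub-key 2" is a new entry added for illustration):

>>> tensordict["key 2", "sub-key 2"] = torch.ones(3, 4, 5)
>>> assert tensordict["key 2", "sub-key 2"].shape == torch.Size([3, 4, 5])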

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
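
For instance, a tensorclass field can itself hold another tensorclass instance, and batch operations such as indexing propagate through the nested structure (a minimal sketch reusing MyData from above; MyDataWithMeta is a hypothetical wrapper):

>>> @tensorclass
... class MyDataWithMeta:
...     data: MyData
...     step: torch.Tensor
...
>>> batch = MyDataWithMeta(
...     data=MyData(images, mask, label=label, batch_size=[100]),
...     step=torch.arange(100),
...     batch_size=[100])
>>> print(batch[:5].data.image.shape)
torch.Size([5, 3, 64, 64])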

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a deprecation warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
