
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can be easily used across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating individual tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
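
A short, non-exhaustive sketch of a few of these operations, reusing the tensordict defined above (everything stays on CPU so it runs anywhere):

>>> td_clone = tensordict.clone()                    # deep copy of every entry
>>> td_cpu = tensordict.to("cpu")                    # device casting (a no-op here)
>>> td_shared = tensordict.clone().share_memory_()   # place the entries in shared memory
>>> chunks = tensordict.split(2, dim=1)              # split along a batch dimension
>>> assert chunks[0].shape == torch.Size([3, 2])
>>> rows = tensordict.unbind(0)                      # tuple of 3 tensordicts with batch size [4]
>>> expanded = tensordict.expand(2, 3, 4)            # broadcast to a new leading dimension
>>> assert expanded.shape == torch.Size([2, 3, 4])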

If a functionality is missing, it is easy to emulate it by applying a function to every entry with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
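
The in-place counterpart apply_() modifies the tensordict it is called on and returns it; a minimal sketch:

tensordict.apply_(lambda tensor: tensor.zero_())  # zero every entry in place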

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
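
As a minimal sketch of that workflow (the temporary directory below stands in for a shared scratch directory), a tensordict can be written to memory-mapped storage with memmap_() and read back, e.g. on another node, with load_memmap():

>>> import tempfile
>>> scratch = tempfile.mkdtemp()  # stand-in for a directory on the shared file system
>>> td_disk = TensorDict({"a": torch.zeros(3), ("b", "c"): torch.ones(3)}, [3])
>>> td_disk.memmap_(prefix=scratch)  # every tensor becomes a memory-mapped file
>>> reloaded = TensorDict.load_memmap(scratch)
>>> assert (reloaded["b", "c"] == 1).all()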

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

On the first assignment, your empty tensordict will automatically be populated with empty tensors of batch size N. After that, updates will be written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion), as the sketch below illustrates.
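
A minimal sketch of the preallocation loop with shuffled indices (reusing foo() from above; N = 10 is illustrative):

N = 10
tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():  # visit the indices in random order
    tensordict[i] = foo()  # the first write preallocates, the rest are written in-place
assert tensordict["a"].shape == torch.Size([N, 3])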

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size must match the parent's batch size in its leading dimensions (it may be longer).

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
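
The flat representation can be turned back into a nested one with unflatten_keys(); a quick round-trip check:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested["key 2", "sub-key"] == tensordict["key 2", "sub-key"]).all()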

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses supporting tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
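
For instance, a minimal sketch of a nested tensorclass, reusing the data instance from above (the class and field names here are illustrative):

>>> @tensorclass
... class MyDataset:
...     data: MyData
...     split: torch.Tensor
...
>>> dataset = MyDataset(data=data, split=torch.zeros(100, dtype=torch.bool), batch_size=[100])
>>> assert dataset[:10].data.image.shape == torch.Size([10, 3, 64, 64])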

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be backward-compatibility-breaking changes introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
