
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of every tensor stored in the tensordict. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
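
For illustration, here is a minimal sketch of a few of these operations (assuming the usual method names to, clone, share_memory_, unbind and expand; see the documentation for the full list):

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> td_cpu = tensordict.to("cpu")                    # device casting
>>> td_clone = tensordict.clone()                    # deep copy of every entry
>>> td_shared = tensordict.clone().share_memory_()   # place the tensors in shared memory
>>> td_a, td_b, td_c = tensordict.unbind(0)          # unbind along the first batch dimension
>>> td_expanded = tensordict.expand(2, 3, 4)         # expand to a larger batch size
>>> assert td_expanded.shape == torch.Size([2, 3, 4])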

If a functionality is missing, it is easy to implement it by applying a function to every entry with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
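
apply() returns a new tensordict built from the outputs of the function, while apply_() (note the trailing underscore) writes the results back into the same tensordict; a minimal sketch:

# in-place variant: every stored tensor is replaced by the output of the function
tensordict.apply_(lambda tensor: tensor.clamp(0, 1))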

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
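
As a rough sketch (the path below is a hypothetical shared mount, and the exact memory-mapping API may vary across versions), a tensordict can be turned into a collection of memory-mapped files that other nodes can read directly:

>>> data = TensorDict({"a": torch.zeros(3), ("b", "c"): torch.ones(3)}, [3])
>>> data.memmap_(prefix="/shared/scratch")  # write every tensor to a memory-mapped file
>>> assert data.is_memmap()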

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
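
As a hedged sketch of what this enables (torch.compile requires PyTorch 2.x, and the exact compilation behavior depends on the installed version), a tensordict-aware module can be compiled like any other nn.Module:

>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> module = torch.compile(module)  # requires PyTorch 2.x
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = module(td)
>>> assert "y" in td.keys()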

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i==0, your empty tensordict will automatically be populated with empty tensors of batch-size N. After that, updates will be written in-place. Note that this would also work with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
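
As a small sketch reusing foo() from above, the same preallocation works when the indices are visited in shuffled order:

import torch

N = 10
tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():   # visit the indices in random order
    tensordict[i] = foo()              # the first write preallocates, later writes are in-place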

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) the parent batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
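
The flat representation can be turned back into a nested one with unflatten_keys; a short sketch continuing the example above:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested["key 2", "sub-key"] == tensordict["key 2", "sub-key"]).all()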

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
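
For instance, a small sketch reusing the data instance above (assuming torch.stack dispatches on tensorclasses as it does on tensordicts):

>>> batch = torch.stack([data, data], 0)
>>> assert batch.batch_size == torch.Size([2, 100])
>>> assert batch[0, :10].batch_size == torch.Size([10])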

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking (bc-breaking) changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.



