TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make codebases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be used across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
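
As a minimal sketch of a few of these operations (reusing the tensordict defined above; the device transfer assumes a CUDA device is available):

>>> td_clone = tensordict.clone()        # deep copy of every entry
>>> td_cuda = tensordict.to("cuda:0")    # assumes a CUDA device is available
>>> chunks = tensordict.split(2, dim=1)  # chunks of tensordicts split along dim 1
>>> rows = tensordict.unbind(0)          # 3 tensordicts with batch_size [4]
>>> td_exp = tensordict.expand(2, 3, 4)  # prepend a broadcast dimension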

If a functionality is missing, it is easy to implement it using apply() or apply_():

# draw a fresh uniform sample for every tensor (uniform_ operates in-place)
tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves this problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend, which behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
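
For instance, the whole content of a tensordict can be dumped to memory-mapped tensors on disk; a minimal sketch (the prefix path below is hypothetical):

>>> data = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
>>> data.memmap_(prefix="/shared/scratch/td")  # hypothetical shared directory
>>> assert data.is_memmap()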

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
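
As a minimal sketch (assuming PyTorch 2.0+ for torch.compile), a tensordict module can be compiled like any other module:

>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> compiled_module = torch.compile(module)
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = compiled_module(td)
>>> assert "y" in td.keys()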

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion), as the sketch below shows.
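
A minimal sketch of the shuffled-index case, reusing foo() and N from above:

tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    tensordict[i] = foo()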

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict must be indexable under the parent tensordict, i.e. its batch size must start with (but may extend) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
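
The reverse operation, unflatten_keys, restores the nested structure; a minimal sketch:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested == tensordict).all()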

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
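
For instance, a tensorclass field can itself be a tensorclass (or a TensorDict); a minimal sketch reusing MyData from above:

>>> @tensorclass
... class Sample:
...     data: MyData
...     weight: torch.Tensor
...
>>> sample = Sample(data=data, weight=torch.rand(100), batch_size=[100])
>>> print(sample.data.image.shape)
torch.Size([100, 3, 64, 64])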

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and above, as well as PyTorch 1.12 and above.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be backward-compatibility-breaking (bc-breaking) changes, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
