
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device, and point-to-point communication in distributed settings.

The main purpose of TensorDict is to make codebases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can be easily used across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must be compliant with the leading dimensions of each of the tensors. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device; each tensor written to it will then be sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimensions:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating individual tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and more.
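
Here is a minimal sketch of a few of these operations, applied to the tensordict defined above; the to("cuda:0") call assumes a CUDA device is available:

>>> td_clone = tensordict.clone()            # deep copy of the tensordict and its tensors
>>> td_cuda = tensordict.to("cuda:0")        # send every tensor to the GPU
>>> td_a, td_b = tensordict.split(2, dim=1)  # two tensordicts of batch size [3, 2]
>>> td_expand = tensordict.expand(10, 3, 4)  # broadcast to an extra leading batch dimension
>>> assert td_expand.shape == torch.Size([10, 3, 4])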

If a functionality is missing, it is easy to apply it tensor-by-tensor using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())  # apply the function to every tensor and gather the results in a new tensordict

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
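
For instance, a tensordict can be converted in-place to a memory-mapped storage with memmap_(). A minimal sketch, assuming /shared/scratch is a directory that every node can read and write:

>>> data = TensorDict({"a": torch.zeros(1000, 128)}, batch_size=[1000])
>>> data.memmap_(prefix="/shared/scratch/data")  # tensors now live on disk rather than in RAM
>>> assert data.is_memmap()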

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
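
As a brief illustration of such a module, TensorDictModule wraps a regular nn.Module and tells it which tensordict entries to read and write (a minimal sketch):

>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = module(td)  # reads "x", computes the output and writes it under "y"
>>> assert td["y"].shape == torch.Size([10, 4])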

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion), as the sketch below shows.
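
For example, the following sketch fills the preallocated tensordict in a random order and yields the same result:

import random

tensordict = TensorDict({}, batch_size=[N])
for i in random.sample(range(N), N):  # visit the indices in shuffled order
    tensordict[i] = foo()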

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict must be indexable under the parent tensordict, i.e. its batch size must match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
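
Writing through a nested key works the same way, and a flattened tensordict can be folded back into a nested one. A minimal sketch (the new sub-key name is arbitrary):

>>> tensordict["key 2", "other-sub-key"] = torch.zeros(3, 4, 5)
>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")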

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
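
For instance, a tensorclass field can itself hold another tensorclass or a TensorDict. A minimal sketch reusing the MyData instance from above (the Sample class is hypothetical):

>>> @tensorclass
... class Sample:
...     data: MyData          # nested tensorclass
...     score: torch.Tensor
...
>>> sample = Sample(data=data, score=torch.rand(100), batch_size=[100])
>>> assert sample[:10].data.image.shape == torch.Size([10, 3, 64, 64])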

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and upward, as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
