TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
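
As a sketch of what such a model can look like, a plain nn.Module can be wrapped with tensordict.nn.TensorDictModule so that it reads its inputs from, and writes its outputs to, a tensordict (the key names and the toy loss below are illustrative, not prescribed by the library):

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# wrap a regular module: read "observation" from the tensordict, write "prediction" back
model = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["prediction"])
data = TensorDict({
    "observation": torch.randn(8, 3),
    "target": torch.randn(8, 1),
}, batch_size=[8])
data = model(data)  # "prediction" is now a key of the tensordict
loss = nn.functional.mse_loss(data["prediction"], data["target"])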

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating multiple tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).batch_size)
torch.Size([12])
>>> print(tensordict.reshape(-1).batch_size)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
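
As a brief sketch of a few of these operations on the tensordict defined above (the split sizes and target device are illustrative):

>>> tensordict_cpu = tensordict.to("cpu")           # device casting
>>> tensordict_clone = tensordict.clone()           # deep copy
>>> td_a, td_b = tensordict.split([1, 2], dim=0)    # split along the first batch dimension
>>> tensordict_big = tensordict.expand(2, 3, 4)     # broadcast to a larger batch size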

If a functionality is missing, it is easy to implement it by applying a function over the tensordict's leaf tensors with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
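
As a sketch, a tensordict can be memory-mapped onto such a shared scratch space with memmap_() (the target path below is illustrative):

>>> data = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, batch_size=[3])
>>> data.memmap_(prefix="/path/to/shared/scratch")  # tensors are now backed by files on disk
>>> assert data.is_memmap()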

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
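
For instance, a tensordict module can be passed to torch.compile like any other nn.Module. The snippet below is a minimal sketch (assuming a PyTorch version where torch.compile is available), reusing the imports from the example above:

>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> compiled_module = torch.compile(module)
>>> data = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> data = compiled_module(data)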

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in order).
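
For instance, the same preallocation with a shuffled index order (a sketch; N and foo() are as defined above):

import random

indices = list(range(N))
random.shuffle(indices)
tensordict = TensorDict({}, batch_size=[N])
for i in indices:
    tensordict[i] = foo()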

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size must start with the parent's batch size (it can have more dimensions than the parent, but not fewer).

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")
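>>> # and back to a nested representation (a sketch of the reverse operation):
>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")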

Values in nested tensordicts can be accessed with a single indexing operation, using a tuple of keys:

>>> sub_value = tensordict["key 2", "sub-key"]

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features, such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
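
For instance, a tensorclass can hold another tensorclass (or a TensorDict) as a field. The sketch below, with illustrative class and field names, nests the MyData instance defined above:

>>> @tensorclass
... class MyNestedData:
...     data: MyData
...     index: torch.Tensor
...
>>> nested = MyNestedData(data, torch.arange(100), batch_size=[100])
>>> print(nested[:5].data.image.shape)
torch.Size([5, 3, 64, 64])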

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
