TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to a device, and point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must match. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device; every tensor written to it is then sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimensions:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating individual tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).batch_size)
torch.Size([12])
>>> print(tensordict.reshape(-1).batch_size)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and so on.
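
A brief sketch of a few of these operations (clone, to, split and share_memory_ mirror their tensor counterparts; the CUDA transfer assumes a CUDA device is available):

>>> tensordict_clone = tensordict.clone()            # deep copy of every entry
>>> chunks = tensordict_clone.split(2, dim=1)        # split along a batch dimension
>>> tensordict_cuda = tensordict_clone.to("cuda:0")  # send every entry to the device
>>> tensordict_clone.share_memory_()                 # place entries in shared memory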

If a functionality is missing, it can easily be implemented with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
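
apply() returns a new tensordict, while the in-place variant apply_() writes the result back into the original. A small sketch:

tensordict.apply_(lambda tensor: tensor.clamp(0, 1))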

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
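
As a sketch, assuming /mnt/shared is a scratch directory visible to all nodes (the path is a placeholder), one worker can dump its tensordict as memory-mapped tensors and another can load it lazily:

>>> # on the writer
>>> data.memmap_(prefix="/mnt/shared/buffer")
>>> # on a reader
>>> data_read = TensorDict.load_memmap("/mnt/shared/buffer")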

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i==0, your empty tensordict will automatically be populated with empty tensors of batch-size N. After that, updates will be written in-place. Note that this would also work with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
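
For instance, this sketch fills the preallocated tensordict in a shuffled order:

tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    tensordict[i] = foo()  # the first write allocates the full storage, later writes are in-place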

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict must be indexable under its parent, i.e. its batch size must start with (but may be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
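
Tuple keys work for writing too; assigning to a new nested key creates the intermediate levels as needed (a small sketch, reusing the tensordict above):

>>> tensordict["key 2", "sub-key 2"] = torch.zeros(3, 4, 5)
>>> assert tensordict["key 2", "sub-key 2"].shape == torch.Size([3, 4, 5])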

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses supporting tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) and the ability to call arbitrary functions through the apply method (among many others).

Tensorclasses support nesting and, in fact, all the TensorDict features.
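
For example, a tensorclass can hold another tensorclass as a field; a minimal sketch building on MyData above:

>>> @tensorclass
... class Batch:
...     data: MyData
...     noise: torch.Tensor
...
>>> batch = Batch(data=data, noise=torch.randn(100), batch_size=[100])
>>> assert batch[:10].data.image.shape == torch.Size([10, 3, 64, 64])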

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use the nightly build:

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
