TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
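
To give a concrete (if minimal) sketch of what such a model can look like, the tensordict.nn.TensorDictModule class wraps an ordinary nn.Module so that it reads its inputs from, and writes its outputs to, a tensordict; the key names "observation" and "prediction" below are illustrative only:

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# wrap a plain linear layer so it reads "observation" and writes "prediction"
module = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["prediction"])
data = TensorDict({"observation": torch.randn(10, 3)}, batch_size=[10])
data = module(data)  # the output is written into the same tensordict
assert "prediction" in data.keys()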

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of every tensor in the tensordict. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimensions:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build larger tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).batch_size)
torch.Size([12])
>>> print(tensordict.reshape(-1).batch_size)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them (in-place or not), split them, unbind them, expand them, and more.
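
A brief sketch of a few of these operations on the tensordict defined above (each mirrors its tensor counterpart):

>>> tensordict_cpu = tensordict.to("cpu")   # device casting
>>> tensordict_clone = tensordict.clone()   # deep copy
>>> chunks = tensordict.split(2, dim=1)     # tensordicts of batch_size [3, 2]
>>> rows = tensordict.unbind(0)             # tuple of tensordicts of batch_size [4]
>>> expanded = tensordict.expand(2, 3, 4)   # batch_size becomes [2, 3, 4]
>>> tensordict.share_memory_()              # place the underlying tensors in shared memory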

If a functionality is missing, it can easily be implemented with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
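
For instance, something along these lines (a rough sketch assuming a shared directory such as /mnt/shared visible to all nodes; check the documentation for the exact memory-mapping API in your version):

>>> # on the node producing the data: back the tensordict with memory-mapped tensors
>>> data = TensorDict({"a": torch.zeros(1000, 128)}, batch_size=[1000])
>>> data.memmap_(prefix="/mnt/shared/scratch")
>>> # on a node reading the data through the shared filesystem
>>> data_read = TensorDict.load_memmap("/mnt/shared/scratch")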

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
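
As a rough sketch of what this compatibility enables (assuming PyTorch 2.0+ for torch.compile):

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> compiled_module = torch.compile(module)
>>> data = compiled_module(TensorDict({"x": torch.randn(10, 3)}, batch_size=[10]))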

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i==0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in order).
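
For instance, filling the pre-allocated tensordict in a random order works just as well:

tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():  # indices visited in arbitrary order
    tensordict[i] = foo()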

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size must match (but can be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")
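
Conversely, unflatten_keys restores the nested structure:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested["key 2", "sub-key"] == tensordict["key 2", "sub-key"]).all()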

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
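
Nested values can be set the same way (the entry name "sub-key 2" below is just for illustration):

>>> tensordict["key 2", "sub-key 2"] = torch.zeros(3, 4, 5, 2)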

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) and calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
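
For instance, tensorclass instances can be stacked just like tensordicts (a quick sketch using the MyData class above):

>>> stacked = torch.stack([data, data], 0)
>>> assert stacked.batch_size == torch.Size([2, 100])
>>> assert stacked.image.shape == torch.Size([2, 100, 3, 64, 64])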

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, and PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be backward-incompatible changes introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
