
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to a device, or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the example above can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must be compatible. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and more.
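For example, here is a minimal sketch continuing with the tensordict defined above (the device name, split sizes and expanded shape are illustrative):

>>> tensordict_cpu = tensordict.to("cpu")  # move every entry to CPU
>>> tensordict_clone = tensordict.clone()  # deep copy of all entries
>>> first, second = tensordict.split([1, 2], dim=0)  # split along the first batch dimension
>>> chunks = tensordict.unbind(1)  # tuple of tensordicts along the second batch dimension
>>> expanded = tensordict.expand(2, 3, 4)  # broadcast to a larger batch size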

If a functionality is missing, it is easy to apply it to every entry using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
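As a minimal sketch (the prefix path is illustrative and assumed to point to a shared filesystem location):

>>> data = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, [3])
>>> # write every entry to memory-mapped tensors stored under the shared prefix
>>> data.memmap_(prefix="/shared/scratch/data")
>>> assert data.is_memmap()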

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
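As an illustration, continuing the session above, here is a minimal sketch of wrapping a module with tensordict.nn.TensorDictModule and compiling it with torch.compile (PyTorch 2.0+; the keys and shapes are illustrative):

>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> compiled_module = torch.compile(module)
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = compiled_module(td)  # reads "x" from td and writes "y" back into it
>>> assert td["y"].shape == torch.Size([10, 4])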

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

At the first assignment (i == 0 above), the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in order), as the sketch below shows.
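A small sketch of the same preallocation with indices visited in random order (here again with N = 10):

>>> N = 10
>>> tensordict = TensorDict({}, batch_size=[N])
>>> for i in torch.randperm(N).tolist():
...     tensordict[i] = foo()  # the first assignment preallocates, later ones write in-place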

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size must match (but can be longer than) that of the parent tensordict.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")
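The reverse operation, unflatten_keys, rebuilds the nested structure from the flat one, so both representations hold the same data:

>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (tensordict_nested["key 2", "sub-key"] == tensordict["key 2", "sub-key"]).all()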

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
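As a minimal sketch (the class and field names here are illustrative), a tensorclass can hold another tensorclass or a TensorDict as a field:

>>> @tensorclass
... class Batch:
...     data: MyData  # a tensorclass nested inside another tensorclass
...     metadata: torch.Tensor
...
>>> batch = Batch(data=data, metadata=torch.arange(100), batch_size=[100])
>>> print(batch.data.image.shape)
torch.Size([100, 3, 64, 64])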

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
