
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to a device, and point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
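As a sketch of what this contract looks like in practice, the model and the loss module only need to agree on the keys they read and write. The network and the key names below are hypothetical:

import torch
import torch.nn.functional as F

net = torch.nn.Linear(3, 4)  # stand-in network for illustration

def model(tensordict):
    # reads "observation", writes "prediction" (hypothetical key names)
    tensordict["prediction"] = net(tensordict["observation"])
    return tensordict

def loss_module(tensordict):
    # reads "prediction" and "target", returns a scalar loss
    return F.mse_loss(tensordict["prediction"], tensordict["target"])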

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of each of the tensors. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case each tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating multiple tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in place (or not), split them, unbind them, expand them, etc.
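A minimal sketch of a few of these operations:

>>> tensordict = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> td_cpu = tensordict.to("cpu")   # device casting
>>> td_clone = tensordict.clone()   # deep copy
>>> left, right = tensordict.split(2, dim=1)
>>> assert left.shape == torch.Size([3, 2])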

If an operation is not directly exposed, it is easy to apply it to every entry with apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
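For instance, a tensordict can be dumped to a shared location with memmap_() and read back from there; the path below is a hypothetical scratch directory:

>>> data = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3])
>>> data.memmap_(prefix="/shared/scratch/data")  # hypothetical shared path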

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights for model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
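As a minimal sketch (exact behavior depends on your PyTorch version), a tensordict module can be handed to torch.compile like any other module:

>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> module_compiled = torch.compile(module)
>>> data = module_compiled(TensorDict({"x": torch.randn(10, 3)}, batch_size=[10]))
>>> assert data["y"].shape == torch.Size([10, 4])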

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

On the first iteration, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion), as sketched below.
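A minimal sketch of the shuffled variant:

tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    tensordict[i] = foo()  # the first write preallocates; later writes are in-place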

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size must match the parent's batch size as a prefix (but can be longer).

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
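Setting nested values works the same way (a minimal sketch; missing intermediate tensordicts are created on the fly):

>>> tensordict["key 3", "sub-key"] = torch.zeros(3, 4)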

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
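As a minimal sketch of nesting (the class and field names are hypothetical):

>>> @tensorclass
... class Annotation:
...     mask: torch.Tensor
...
>>> @tensorclass
... class Sample:
...     image: torch.Tensor
...     annotation: Annotation
...
>>> sample = Sample(
...     image=torch.randn(10, 3, 64, 64),
...     annotation=Annotation(mask=torch.zeros(10, 64, 64, dtype=torch.bool), batch_size=[10]),
...     batch_size=[10],
... )
>>> assert sample[:3].annotation.mask.shape == torch.Size([3, 64, 64])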

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7+ and PyTorch 1.12+.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be backward-compatibility-breaking changes, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
