TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of all the stored tensors. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimensions:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).shape)
torch.Size([12])
>>> print(tensordict.reshape(-1).shape)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or out-of-place, split them, unbind them, expand them, and more.
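
As a minimal sketch (reusing the tensordict with batch size [3, 4] defined above, and assuming a CUDA device is available; exact signatures may vary across versions):

>>> td_cuda = tensordict.to("cuda:0")               # casts every tensor to the device
>>> td_shared = tensordict.clone().share_memory_()  # a deep copy placed in shared memory
>>> chunks = tensordict.split(2, dim=1)             # splits along the second batch dimension
>>> rows = tensordict.unbind(0)                     # a tuple of 3 tensordicts of batch size [4]
>>> td_expanded = tensordict.expand(2, 3, 4)        # adds a leading batch dimension of size 2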

If a functionality is missing, it is easy to call it using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())
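
For instance, a sketch of the in-place variant (assuming apply_() mirrors apply() but writes the results back into the same tensordict):

tensordict.apply_(lambda tensor: tensor.clamp_(0, 1))  # modifies the stored tensors in place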

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
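
A minimal sketch of this workflow, assuming a shared path such as /shared/scratch/td (the path is purely illustrative) and that memmap_() / load_memmap() behave as in recent releases:

>>> data = TensorDict({"a": torch.zeros(3, 4), "b": torch.ones(3, 4)}, batch_size=[3])
>>> # write every tensor to memory-mapped files under the shared prefix
>>> data.memmap_(prefix="/shared/scratch/td")
>>> # on another node sharing the same filesystem, read the data back
>>> data_read = TensorDict.load_memmap("/shared/scratch/td")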

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!
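
As a rough sketch (TensorDictModule is tensordict's nn.Module wrapper; the in_keys/out_keys below are arbitrary names, and torch.compile requires PyTorch 2.0+):

>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
>>> compiled_module = torch.compile(module)  # compiles the wrapped module
>>> td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
>>> td = compiled_module(td)                 # reads "x", writes "y" into the tensordict
>>> assert "y" in td.keys()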

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict with the same structure (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

At the first iteration, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
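
For instance, a sketch of the shuffled-index variant mentioned above (with N and foo() defined as before):

tensordict = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    # the first assignment, whichever index it hits, triggers the pre-allocation
    tensordict[i] = foo()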

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict must be indexable under the parent tensordict, i.e. its batch size must start with (but may be longer than) the parent batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Nested values can be accessed with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]
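
Nested values can also be written with the same tuple-key syntax, and a flat tensordict can be turned back into a nested one; a short sketch (the "sub-key 2" entry below is only illustrative):

>>> tensordict["key 2", "sub-key 2"] = torch.ones(3, 4, 5)  # creates the nested entry
>>> tensordict_nested = tensordict_flatten.unflatten_keys(separator=".")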

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) and calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
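
For instance, a rough sketch of nesting and stacking the tensorclass defined above (the LabelledBatch class is purely illustrative, and assumes nested tensorclass fields and torch.stack behave as with TensorDict):

>>> @tensorclass
... class LabelledBatch:
...     sample: MyData            # a tensorclass nested inside another tensorclass
...     split: torch.Tensor
...
>>> batch = LabelledBatch(sample=data, split=torch.zeros(100), batch_size=[100])
>>> assert batch[:10].sample.image.shape == torch.Size([10, 3, 64, 64])
>>> stacked = torch.stack([batch, batch.clone()], dim=0)  # TensorDict-style stacking
>>> assert stacked.batch_size == torch.Size([2, 100])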

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that bc-breaking changes may be introduced, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
