
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming | TensorDict for parameter serialization | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can be reused across classification and segmentation tasks, among many others.
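For instance, the model in the loop above can be any callable that reads from and writes to a tensordict. Here is a minimal sketch using tensordict.nn.TensorDictModule, which wraps a regular nn.Module so that it pulls its inputs from, and writes its outputs to, the tensordict (the key names here are illustrative):

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# wrap a plain module so it reads "image" and writes "logits" in the tensordict
model = TensorDictModule(nn.Linear(16, 10), in_keys=["image"], out_keys=["logits"])

data = TensorDict({"image": torch.randn(4, 16)}, batch_size=[4])
data = model(data)  # "logits" is written into data, which is returned
assert data["logits"].shape == torch.Size([4, 10])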

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must be compliant with the leading dimensions of every tensor. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> data["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert data["key 3"].device is torch.device("cuda:0")

But that is not all: you can also store nested values in a tensordict:

>>> data["nested", "key"] = torch.zeros(3, 4) # the batch-size must match

and any nested tuple structure will be unravelled to make it easy to read code and write ops programmatically:

>>> data["nested", ("supernested", ("key",))] = torch.zeros(3, 4) # the batch-size must match
>>> assert (data["nested", "supernested", "key"] == 0).all()
>>> assert (("nested",), "supernested", (("key",),)) in data.keys(include_nested=True)  # this works too!

You can also store non-tensor data in tensordicts:

>>> data = TensorDict({"a-tensor": torch.randn(1, 2)}, batch_size=[1, 2])
>>> data["non-tensor"] = "a string!"
>>> assert data["non-tensor"] == "a string!"

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = data[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating individual tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(data.view(-1).batch_size)
torch.Size([12])
>>> print(data.reshape(-1).batch_size)
torch.Size([12])
>>> print(data.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and so on.
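A few of these operations in a quick sketch (the CUDA cast is shown as a comment since it requires a GPU):

>>> data = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> data_clone = data.clone()      # deep copy of every entry
>>> chunks = data.split(1, dim=0)  # tuple of 3 tensordicts with batch_size [1, 4]
>>> rows = data.unbind(0)          # tuple of 3 tensordicts with batch_size [4]
>>> bigger = data.expand(2, 3, 4)  # batch_size becomes [2, 3, 4]
>>> data = data.share_memory_()    # place every tensor in shared memory
>>> # data.to("cuda:0") would likewise move every entry to the GPU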

If a functionality is missing, it is easy to implement it entry-wise with apply() or its in-place counterpart apply_():

tensordict_uniform = data.apply(lambda tensor: tensor.uniform_())  # fills every entry with samples from U(0, 1)

apply() is also handy for filtering a tensordict: entries for which the function returns None are dropped. For instance:

data = TensorDict({"a": torch.tensor(1.0, dtype=torch.float), "b": torch.tensor(1, dtype=torch.int64)}, [])
data_float = data.apply(lambda x: x if x.dtype == torch.float else None) # contains only the "a" key
assert "b" not in data_float

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
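A sketch of this pattern (the folder name is illustrative and assumed to be visible to both nodes):

>>> # on the writer node: dump the tensordict to the shared scratch space
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> data.memmap_("/shared/scratch/td")
>>> # on the reader node: map the same files into memory, without any network transfer
>>> data = TensorDict.load_memmap("/shared/scratch/td")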

TensorDict for functional programming

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights for model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> params = TensorDict.from_module(model)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> with params.to_module(model):
...     out = model(x)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> def func(x, params):
...     with params.to_module(model):
...         return model(x)
>>> y = vmap(func, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and (soon) torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

TensorDict for parameter serialization and building datasets

TensorDict offers an API for parameter serialization that can be >3x faster than regular calls to torch.save(state_dict). Moreover, because each tensor is saved independently on disk, you can deserialize a checkpoint on an arbitrary slice of the model.

>>> model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 3))
>>> params = TensorDict.from_module(model)
>>> params.memmap("/path/to/saved/folder/", num_threads=16)  # adjust num_threads for speed
>>> # load params
>>> params = TensorDict.load_memmap("/path/to/saved/folder/", num_threads=16)
>>> params.to_module(model)  # load onto model
>>> params["0"].to_module(model[0])  # load on a slice of the model
>>> # in the latter case we could also have loaded only the slice we needed
>>> params0 = TensorDict.load_memmap("/path/to/saved/folder/0", num_threads=16)
>>> params0.to_module(model[0])  # load on a slice of the model

The same functionality can be used to access data in a dataset stored on disk. Storing a single contiguous tensor on disk through the tensordict.MemoryMappedTensor primitive and reading slices of it is not only much faster than loading single files one at a time, but also easier and safer (because there is no pickling or third-party library involved):

# allocate memory of the dataset on disk
data = TensorDict({
    "images": torch.zeros((128, 128, 3), dtype=torch.uint8),
    "labels": torch.zeros((), dtype=torch.int)}, batch_size=[])
data = data.expand(1000000)
data = data.memmap_like("/path/to/dataset")
# ==> Fill your dataset here
# Let's get 3 items of our dataset:
data[torch.tensor([1, 10000, 500000])]  # This is much faster than loading the 3 images independently

Preprocessing with TensorDict.map

Preprocessing huge contiguous (or not!) datasets can be done via TensorDict.map, which dispatches pieces of the tensordict to multiple workers:

import torch
from tensordict import TensorDict

def process_data(data):
    images = data.get("images").flip(-2).clone()
    labels = data.get("labels") // 10
    # we update the tensordict in-place
    data.set_("images", images)  # flip the images
    data.set_("labels", labels)  # cluster the labels
    return data

if __name__ == "__main__":
    # lazily load the memory-mapped dataset built above
    data = TensorDict.load_memmap("/path/to/dataset")
    # chunksize=0 sends one item at a time to each of the 4 workers
    data_preproc = data.map(process_data, num_workers=4, chunksize=0, pbar=True)

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    data = TensorDict({}, batch_size=[])
    data["a"] = torch.randn(3)
    data["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return data

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

data = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

data = TensorDict({}, batch_size=[N])
for i in range(N):
    data[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in order).
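For instance (reusing foo and N from above):

data = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    data[i] = foo()  # the first write, whatever its index, triggers the full allocation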

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict be indexable under the parent tensordict, i.e. its batch size must match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = data.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = data["key 2", "sub-key"]
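The reverse operation, unflatten_keys, restores the nested layout, so the two representations round-trip:

>>> data_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (data_nested["key 2", "sub-key"] == data["key 2", "sub-key"]).all()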

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features such as shape operations (e.g. reshape or permute), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
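For instance, a tensorclass can hold another tensorclass as a field. A minimal sketch, reusing MyData and data from above (the class and field names here are illustrative):

>>> @tensorclass
... class Annotated:
...     sample: MyData
...     score: torch.Tensor
...
>>> batch = Annotated(sample=data, score=torch.rand(100), batch_size=[100])
>>> # shape operations propagate through the nested tensorclass
>>> assert batch.reshape(10, 10).sample.image.shape == torch.Size([10, 10, 3, 64, 64])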

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking (bc-breaking) changes may be introduced, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
