
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming | TensorDict for parameter serialization | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings. Whenever you need to execute an operation over a batch of tensors, TensorDict is there to help you.

The primary goal of TensorDict is to make your code-bases more readable, compact, and modular. It abstracts away tailored operations, making your code less error-prone as it takes care of dispatching the operation on the leaves for you.

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be used across classification and segmentation tasks, among many others.
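
As an illustration, the loss module in such a loop is simply a callable that reads the keys it needs from the tensordict. Below is a minimal sketch; the "prediction" and "target" key names are placeholders, not part of the library:

import torch.nn.functional as F

def loss_module(data):
    # read the entries written by the model and the dataset
    # (illustrative key names)
    return F.mse_loss(data["prediction"], data["target"])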

Features

General principles

Unlike other pytrees, TensorDict carries metadata that make it easy to query the state of the container. The main metadata are the batch_size (also referred to as the shape), the device, the shared status (is_memmap or is_shared), the dimension names and the lock status.
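
These metadata can be queried directly; a quick sketch:

>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({"obs": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> data.batch_size  # alias: data.shape
torch.Size([3, 4])
>>> print(data.device)  # None unless a device was set
None
>>> data.is_shared()
False
>>> data.is_locked
False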

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of each of the tensors. The tensors can be of any dtype and device.
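
For instance, a tensor whose leading dimensions disagree with the batch size is rejected; a sketch (the exact error type may vary across versions):

>>> try:
...     TensorDict({"bad key": torch.zeros(2, 4)}, batch_size=[3, 4])
... except Exception:  # the exact error type may vary across versions
...     print("rejected")
rejected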

Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it will be sent there:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")

When a tensordict has a device, all write operations will cast the tensor to the TensorDict device:

>>> data["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert data["key 3"].device is torch.device("cuda:0")

Once the device is set, it can be cleared with the clear_device_ method.
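
For example:

>>> data = data.clear_device_()  # clear_device_ returns self
>>> assert data.device is None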

TensorDict as a specialized dictionary

TensorDict possesses all the basic features of a dictionary such as clear, copy, fromkeys, get, items, keys, pop, popitem, setdefault, update and values.
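
A short sketch of this dictionary-style API:

>>> data = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> value = data.get("b", None)  # like dict.get, returns the default if the key is absent
>>> assert value is None
>>> data = data.update({"b": torch.ones(3, 4)})  # update returns self
>>> assert sorted(data.keys()) == ["a", "b"]
>>> popped = data.pop("a")
>>> assert "a" not in data.keys()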

But that is not all, you can also store nested values in a tensordict:

>>> data["nested", "key"] = torch.zeros(3, 4) # the batch-size must match

and any nested tuple structure will be unravelled to make it easy to read code and write ops programmatically:

>>> data["nested", ("supernested", ("key",))] = torch.zeros(3, 4) # the batch-size must match
>>> assert (data["nested", "supernested", "key"] == 0).all()
>>> assert (("nested",), "supernested", (("key",),)) in data.keys(include_nested=True)  # this works too!

You can also store non-tensor data in tensordicts:

>>> data = TensorDict({"a-tensor": torch.randn(1, 2)}, batch_size=[1, 2])
>>> data["non-tensor"] = "a string!"
>>> assert data["non-tensor"] == "a string!"

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = data[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(data.view(-1).batch_size)
torch.Size([12])
>>> print(data.reshape(-1).batch_size)
torch.Size([12])
>>> print(data.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordict from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them etc.
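
A few of these operations, sketched on the tensordict from the previous example:

>>> data_clone = data.clone()                   # deep copy of the leaves
>>> chunks = data.split(2, dim=0)               # tuple of tensordicts along dim 0
>>> rows = data.unbind(1)                       # 4 tensordicts of batch_size [3]
>>> bigger = data.unsqueeze(0).expand(2, 3, 4)  # batch_size becomes [2, 3, 4]
>>> shared = data.clone().share_memory_()       # place the content in shared memory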

If a functionality is missing, it is easy to call it using apply() or apply_():

tensordict_uniform = data.apply(lambda tensor: tensor.uniform_())

apply() is also handy for filtering a tensordict, for instance:

data = TensorDict({"a": torch.tensor(1.0, dtype=torch.float), "b": torch.tensor(1, dtype=torch.int64)}, [])
data_float = data.apply(lambda x: x if x.dtype == torch.float else None) # contains only the "a" key
assert "b" not in data_float

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
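
A minimal sketch, assuming a scratch path /shared/scratch/td that is visible from every node:

>>> # on the producer node: dump the content on the shared storage
>>> data = data.memmap_("/shared/scratch/td")
>>> # on any consumer node: map the same files without copying the data
>>> data = TensorDict.load_memmap("/shared/scratch/td")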

TensorDict for functional programming

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> params = TensorDict.from_module(model)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> with params.to_module(model):
...     out = model(x)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> def func(x, params):
...     with params.to_module(model):
...         return model(x)
>>> y = vmap(func, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and (soon) torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

TensorDict for parameter serialization and building datasets

TensorDict offers an API for parameter serialization that can be >3x faster than regular calls to torch.save(state_dict). Moreover, because tensors will be saved independently on disk, you can deserialize your checkpoint on an arbitrary slice of the model.

>>> model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 3))
>>> params = TensorDict.from_module(model)
>>> params.memmap("/path/to/saved/folder/", num_threads=16)  # adjust num_threads for speed
>>> # load params
>>> params = TensorDict.load_memmap("/path/to/saved/folder/", num_threads=16)
>>> params.to_module(model)  # load onto model
>>> params["0"].to_module(model[0])  # load on a slice of the model
>>> # in the latter case we could also have loaded only the slice we needed
>>> params0 = TensorDict.load_memmap("/path/to/saved/folder/0", num_threads=16)
>>> params0.to_module(model[0])  # load on a slice of the model

The same functionality can be used to access data in a dataset stored on disk. Storing the dataset as a single contiguous tensor on disk, accessed through the tensordict.MemoryMappedTensor primitive, and reading slices of it is not only much faster than loading single files one at a time, but also easier and safer (because there is no pickling or third-party library involved):

# allocate memory of the dataset on disk
data = TensorDict({
    "images": torch.zeros((128, 128, 3), dtype=torch.uint8),
    "labels": torch.zeros((), dtype=torch.int)}, batch_size=[])
data = data.expand(1000000)
data = data.memmap_like("/path/to/dataset")
# ==> Fill your dataset here
# Let's get 3 items of our dataset:
data[torch.tensor([1, 10000, 500000])]  # This is much faster than loading the 3 images independently

Preprocessing with TensorDict.map

Preprocessing huge contiguous (or not!) datasets can be done via TensorDict.map which will dispatch a task to various workers:

import torch
from tensordict import TensorDict, MemoryMappedTensor
import tempfile

def process_data(data):
    images = data.get("images").flip(-2).clone()
    labels = data.get("labels") // 10
    # we update the td in-place
    data.set_("images", images)  # flip the images
    data.set_("labels", labels)  # cluster the labels
    return data

if __name__ == "__main__":
    # `data` is the memory-mapped dataset allocated above
    data_preproc = data.map(process_data, num_workers=4, chunksize=0, pbar=True)  # process one item at a time

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    data = TensorDict({}, batch_size=[])
    data["a"] = torch.randn(3)
    data["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return data

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

data = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

data = TensorDict({}, batch_size=[N])
for i in range(N):
    data[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i==0, your empty tensordict will automatically be populated with empty tensors of batch-size N. After that, updates will be written in-place. Note that this would also work with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
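
For instance, the same preallocation works over a shuffled sequence of indices:

data = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    data[i] = foo()  # the first visited index triggers the allocation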

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e., its batch size should match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = data.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = data["key 2", "sub-key"]
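
The reverse operation, unflatten_keys, restores the hierarchy from the flat representation:

>>> data_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (data_nested["key 2", "sub-key"] == data["key 2", "sub-key"]).all()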

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features such as shape operations (e.g., reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
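
For example, a tensorclass field can itself be a tensorclass or a TensorDict; a sketch reusing the MyData class and the data instance from above:

>>> @tensorclass
... class Batch:
...     inner: MyData  # a nested tensorclass
...     weight: torch.Tensor
...
>>> batch = Batch(inner=data, weight=torch.ones(100), batch_size=[100])
>>> print(batch.reshape(10, 10).batch_size)  # shape ops propagate to nested fields
torch.Size([10, 10])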

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward, as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
