
TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming | TensorDict for parameter serialization | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings. Whenever you need to execute an operation over a batch of tensors, TensorDict is there to help you.

The primary goal of TensorDict is to make your code-bases more readable, compact, and modular. It abstracts away tailored operations, making your code less error-prone as it takes care of dispatching the operation on the leaves for you.

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General principles

Unlike other pytrees, TensorDict carries metadata that makes it easy to query the state of the container. The main metadata are the batch_size (also referred to as shape), the device, the shared status (is_memmap or is_shared), the dimension names and the lock status.

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of each of the tensors. The tensors can be of any dtype and device.
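
A quick sketch of how the metadata mentioned above can be inspected on the data defined here:

>>> print(data.batch_size)   # torch.Size([3, 4])
>>> print(data.device)       # None, unless a device was passed at construction
>>> print(data.names)        # [None, None] unless dimension names were set
>>> print(data.is_shared())  # False: not in shared memory
>>> print(data.is_locked)    # False: keys can still be added or removed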

Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")

When a tensordict has a device, all write operations will cast the tensor to the TensorDict device:

>>> data["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert data["key 3"].device is torch.device("cuda:0")

Once the device is set, it can be cleared with the clear_device_ method.
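
For instance (a small sketch; clear_device_ resets the device metadata without moving the stored tensors):

>>> data.clear_device_()
>>> assert data.device is None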

TensorDict as a specialized dictionary

TensorDict possesses all the basic features of a dictionary such as clear, copy, fromkeys, get, items, keys, pop, popitem, setdefault, update and values.
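
A short sketch of this dictionary-style API using only the standard dict-like methods listed above:

>>> data = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> data.update({"b": torch.ones(3, 4)})        # merge another mapping
>>> assert set(data.keys()) == {"a", "b"}
>>> default = data.get("c", torch.zeros(3, 4))  # get with a default value
>>> b = data.pop("b")
>>> assert "b" not in data.keys()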

But that is not all, you can also store nested values in a tensordict:

>>> data["nested", "key"] = torch.zeros(3, 4) # the batch-size must match

and any nested tuple structure will be unravelled to make it easy to read code and write ops programmatically:

>>> data["nested", ("supernested", ("key",))] = torch.zeros(3, 4) # the batch-size must match
>>> assert (data["nested", "supernested", "key"] == 0).all()
>>> assert (("nested",), "supernested", (("key",),)) in data.keys(include_nested=True)  # this works too!

You can also store non-tensor data in tensordicts:

>>> data = TensorDict({"a-tensor": torch.randn(1, 2)}, batch_size=[1, 2])
>>> data["non-tensor"] = "a string!"
>>> assert data["non-tensor"] == "a string!"

Tensor-like features

[Nightly feature] TensorDict supports many common point-wise arithmetic operations such as ==, + and += (provided that the underlying tensors support the operation in question):

>>> td = TensorDict.fromkeys(["a", "b", "c"], 0)
>>> td += 1
>>> assert (td==1).all()

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict whose tensors are indexed along the required dimension:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = data[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(data.view(-1).batch_size)
torch.Size([12])
>>> print(data.reshape(-1).batch_size)
torch.Size([12])
>>> print(data.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordict from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them etc.
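
A few of these operations, sketched on the tensordict defined above (the device cast on the last line assumes a GPU is available):

>>> data_shared = data.clone().share_memory_()  # place a copy in shared memory
>>> chunks = data.split(2, dim=0)               # tuple of tensordicts along dim 0
>>> rows = data.unbind(1)                       # tuple of tensordicts along dim 1
>>> bigger = data.expand(10, 3, 4)              # add a leading batch dimension
>>> assert bigger.batch_size == torch.Size([10, 3, 4])
>>> data_cuda = data.to("cuda:0")               # device cast (requires a GPU)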

If a functionality is missing, it is easy to call it using apply() or apply_():

tensordict_uniform = data.apply(lambda tensor: tensor.uniform_())  # returns a new tensordict whose leaves have been resampled uniformly in [0, 1)

apply() can also be used to filter a tensordict, for instance:

data = TensorDict({"a": torch.tensor(1.0, dtype=torch.float), "b": torch.tensor(1, dtype=torch.int64)}, [])
data_float = data.apply(lambda x: x if x.dtype == torch.float else None) # contains only the "a" key
assert "b" not in data_float

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
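
A minimal sketch of that workflow, assuming /path/to/shared/scratch is a folder visible to both nodes:

>>> # on the writer node
>>> data = TensorDict({"a": torch.zeros(3), ("b", "c"): torch.ones(3)}, batch_size=[3])
>>> data.memmap_("/path/to/shared/scratch")  # dump every tensor as a memory-mapped file
>>> # on the reader node (same file system)
>>> data_read = TensorDict.load_memmap("/path/to/shared/scratch")
>>> assert (data_read["b", "c"] == 1).all()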

TensorDict for functional programming

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> params = TensorDict.from_module(model)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> with params.to_module(model):
...     out = model(x)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> def func(x, params):
...     with params.to_module(model):
...         return model(x)
>>> y = vmap(func, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and (soon) torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

TensorDict for parameter serialization and building datasets

TensorDict offers an API for parameter serialization that can be >3x faster than regular calls to torch.save(state_dict). Moreover, because tensors will be saved independently on disk, you can deserialize your checkpoint on an arbitrary slice of the model.

>>> model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 3))
>>> params = TensorDict.from_module(model)
>>> params.memmap("/path/to/saved/folder/", num_threads=16)  # adjust num_threads for speed
>>> # load params
>>> params = TensorDict.load_memmap("/path/to/saved/folder/", num_threads=16)
>>> params.to_module(model)  # load onto model
>>> params["0"].to_module(model[0])  # load on a slice of the model
>>> # in the latter case we could also have loaded only the slice we needed
>>> params0 = TensorDict.load_memmap("/path/to/saved/folder/0", num_threads=16)
>>> params0.to_module(model[0])  # load on a slice of the model

The same functionality can be used to access data in a dataset stored on disk. Storing a dataset as a single contiguous tensor on disk, accessed through the tensordict.MemoryMappedTensor primitive, and reading slices of it is not only much faster than loading single files one at a time, but it is also easier and safer (because there is no pickling or third-party library involved):

# allocate memory of the dataset on disk
data = TensorDict({
    "images": torch.zeros((128, 128, 3), dtype=torch.uint8),
    "labels": torch.zeros((), dtype=torch.int)}, batch_size=[])
data = data.expand(1000000)
data = data.memmap_like("/path/to/dataset")
# ==> Fill your dataset here
# Let's get 3 items of our dataset:
data[torch.tensor([1, 10000, 500000])]  # This is much faster than loading the 3 images independently

Preprocessing with TensorDict.map

Preprocessing huge datasets, whether contiguous on disk or not, can be done via TensorDict.map, which dispatches the work to multiple workers:

import torch
from tensordict import TensorDict, MemoryMappedTensor
import tempfile

def process_data(data):
    images = data.get("images").flip(-2).clone()
    labels = data.get("labels") // 10
    # we update the td inplace
    data.set_("images", images)  # flip image
    data.set_("labels", labels)  # cluster labels
    return data

if __name__ == "__main__":
    # load or build `data` here, e.g. the memory-mapped dataset created above
    data_preproc = data.map(process_data, num_workers=4, chunksize=0, pbar=True)  # process one item at a time

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    data = TensorDict({}, batch_size=[])
    data["a"] = torch.randn(3)
    data["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return data

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

data = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

data = TensorDict({}, batch_size=[N])
for i in range(N):
    data[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

On the first iteration (i == 0), the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in-place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion), as sketched below.
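
A small sketch of the same loop with shuffled indices, reusing foo() and N from above:

data = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():  # visit the indices in random order
    data[i] = foo()                   # the first write allocates, later writes are in-place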

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e., its batch size must match (but can be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = data.flatten_keys(separator=".")
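
The reverse direction works too; a short sketch going back to the nested layout:

>>> assert "key 2.sub-key" in tensordict_flatten.keys()
>>> data_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert ("key 2", "sub-key") in data_nested.keys(include_nested=True)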

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = data["key 2", "sub-key"]

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support tensordict operations:

>>> from tensordict import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
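
Nesting can be sketched as follows (the Annotated class below is hypothetical, reusing MyData and data from the example above):

>>> @tensorclass
... class Annotated:
...     sample: MyData         # a tensorclass field can itself be a tensorclass
...     score: torch.Tensor
...
>>> annotated = Annotated(sample=data, score=torch.rand(100), batch_size=[100])
>>> sub = annotated[:10]       # indexing propagates to the nested tensorclass
>>> assert sub.sample.image.shape[0] == 10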

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and above, as well as PyTorch 1.12 and above.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may still be introduced, though they should come with advance warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
