TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming | TensorDict for parameter serialization | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings. Whenever you need to execute an operation over a batch of tensors, TensorDict is there to help you.

The primary goal of TensorDict is to make your code-bases more readable, compact, and modular. It abstracts away tailored operations, making your code less error-prone as it takes care of dispatching the operation on the leaves for you.

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General principles

Unlike other pytrees, TensorDict carries metadata that make it easy to query the state of the container. The main metadata are the batch_size (also referred to as shape), the device, the shared status (is_memmap or is_shared), the dimension names and the lock status.
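
These metadata can be queried directly on the container; here is a minimal sketch (the printed values are indicative):

>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({"obs": torch.zeros(3, 4)}, batch_size=[3])
>>> print(data.batch_size)  # alias: data.shape
torch.Size([3])
>>> print(data.device)  # None unless a device was set
None
>>> print(data.is_shared(), data.is_memmap(), data.is_locked)
False False False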

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The leading dimensions of every tensor must match the batch_size. The tensors can be of any dtype and device.
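
As an illustration of this constraint, writing a tensor whose leading dimensions do not match the batch_size is expected to fail (a sketch; the exact error type may differ):

>>> data = TensorDict({}, batch_size=[3, 4])
>>> data["key ok"] = torch.zeros(3, 4, 7)  # leading dims match the batch_size
>>> data["key bad"] = torch.zeros(2, 4, 7)  # leading dims mismatch: raises an exception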

Optionally, one can restrict a tensordict to live on a dedicated device; every tensor written to it will then be sent to that device:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")

When a tensordict has a device, all write operations will cast the tensor to the TensorDict device:

>>> data["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert data["key 3"].device is torch.device("cuda:0")

Once the device is set, it can be cleared with the clear_device_ method.
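
For instance (a sketch assuming a CUDA device is available), once the device is cleared, new entries keep the device they were created on:

>>> data = TensorDict({}, batch_size=[3, 4], device="cuda:0")
>>> data.clear_device_()
>>> data["key"] = torch.randn(3, 4)  # written as-is, no cast to cuda:0
>>> assert data["key"].device == torch.device("cpu")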

TensorDict as a specialized dictionary

TensorDict possesses all the basic features of a dictionary such as clear, copy, fromkeys, get, items, keys, pop, popitem, setdefault, update and values.
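
For instance, the usual dictionary idioms carry over directly (a minimal sketch):

>>> data = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> data.update({"b": torch.ones(3, 4)})  # merge another mapping
>>> popped = data.pop("a")  # remove and return an entry
>>> default = data.get("c", default=torch.full((3, 4), 2.0))  # fall back when the key is missing
>>> assert set(data.keys()) == {"b"}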

But that is not all, you can also store nested values in a tensordict:

>>> data["nested", "key"] = torch.zeros(3, 4) # the batch-size must match

and any nested tuple structure will be unravelled to make it easy to read code and write ops programmatically:

>>> data["nested", ("supernested", ("key",))] = torch.zeros(3, 4) # the batch-size must match
>>> assert (data["nested", "supernested", "key"] == 0).all()
>>> assert (("nested",), "supernested", (("key",),)) in data.keys(include_nested=True)  # this works too!

You can also store non-tensor data in tensordicts:

>>> data = TensorDict({"a-tensor": torch.randn(1, 2)}, batch_size=[1, 2])
>>> data["non-tensor"] = "a string!"
>>> assert data["non-tensor"] == "a string!"

Tensor-like features

[Nightly feature] TensorDict supports many common point-wise arithmetic operations such as ==, + or += and similar (provided that the underlying tensors support the operation in question):

>>> td = TensorDict.fromkeys(["a", "b", "c"], 0)
>>> td += 1
>>> assert (td==1).all()

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = data[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating individual tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(data.view(-1).shape)
torch.Size([12])
>>> print(data.reshape(-1).shape)
torch.Size([12])
>>> print(data.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, etc.
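
A few of these operations, sketched together (the CUDA transfer assumes a GPU is available):

>>> data = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> data_cuda = data.to("cuda:0")  # device-to-device copy
>>> data_shared = data.clone().share_memory_()  # copy placed in shared memory
>>> chunks = data.split([1, 2], dim=0)  # tensordicts with batch sizes [1, 4] and [2, 4]
>>> rows = data.unbind(0)  # tuple of 3 tensordicts with batch size [4]
>>> bigger = data.expand(10, 3, 4)  # add a leading dimension
>>> assert bigger.batch_size == torch.Size([10, 3, 4])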

If a functionality is missing, it is easy to call it using apply() or apply_():

tensordict_uniform = data.apply(lambda tensor: tensor.uniform_())

apply() can also be used to filter a tensordict, for instance:

data = TensorDict({"a": torch.tensor(1.0, dtype=torch.float), "b": torch.tensor(1, dtype=torch.int64)}, [])
data_float = data.apply(lambda x: x if x.dtype == torch.float else None) # contains only the "a" key
assert "b" not in data_float

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
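
A minimal sketch of that pattern, assuming /mnt/shared is a scratch directory visible to every node:

>>> # on the node that produces the data
>>> data = TensorDict({"obs": torch.randn(1000, 84, 84)}, batch_size=[1000])
>>> data.memmap_("/mnt/shared/buffer")  # every tensor is written as a memory-mapped file
>>> # on any node that consumes the data
>>> data = TensorDict.load_memmap("/mnt/shared/buffer")
>>> batch = data[:128]  # only the requested slice is read from disk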

TensorDict for functional programming

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> params = TensorDict.from_module(model)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> with params.to_module(model):
...     out = model(x)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> def func(x, params):
...     with params.to_module(model):
...         return model(x)
>>> y = vmap(func, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and (soon) torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

TensorDict for parameter serialization and building datasets

TensorDict offers an API for parameter serialization that can be >3x faster than regular calls to torch.save(state_dict). Moreover, because tensors will be saved independently on disk, you can deserialize your checkpoint on an arbitrary slice of the model.

>>> model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 3))
>>> params = TensorDict.from_module(model)
>>> params.memmap("/path/to/saved/folder/", num_threads=16)  # adjust num_threads for speed
>>> # load params
>>> params = TensorDict.load_memmap("/path/to/saved/folder/", num_threads=16)
>>> params.to_module(model)  # load onto model
>>> params["0"].to_module(model[0])  # load on a slice of the model
>>> # in the latter case we could also have loaded only the slice we needed
>>> params0 = TensorDict.load_memmap("/path/to/saved/folder/0", num_threads=16)
>>> params0.to_module(model[0])  # load on a slice of the model

The same functionality can be used to access data in a dataset stored on disk. Storing a single contiguous tensor on disk, accessing it through the tensordict.MemoryMappedTensor primitive, and reading slices of it is not only much faster than loading single files one at a time, but also easier and safer (because there is no pickling or third-party library involved):

# allocate memory of the dataset on disk
data = TensorDict({
    "images": torch.zeros((128, 128, 3), dtype=torch.uint8),
    "labels": torch.zeros((), dtype=torch.int)}, batch_size=[])
data = data.expand(1000000)
data = data.memmap_like("/path/to/dataset")
# ==> Fill your dataset here
# Let's get 3 items of our dataset:
data[torch.tensor([1, 10000, 500000])]  # This is much faster than loading the 3 images independently

Preprocessing with TensorDict.map

Preprocessing huge contiguous (or not!) datasets can be done via TensorDict.map which will dispatch a task to various workers:

import torch
from tensordict import TensorDict

def process_data(data):
    images = data.get("images").flip(-2).clone()
    labels = data.get("labels") // 10
    # we update the td in-place
    data.set_("images", images)  # flip image
    data.set_("labels", labels)  # cluster labels
    return data

if __name__ == "__main__":
    # build the memory-mapped `data` tensordict here (see the dataset example above)
    data_preproc = data.map(process_data, num_workers=4, chunksize=0, pbar=True)  # chunksize=0: process one item at a time

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    data = TensorDict({}, batch_size=[])
    data["a"] = torch.randn(3)
    data["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return data

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

data = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

data = TensorDict({}, batch_size=[N])
for i in range(N):
    data[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i==0, your empty tensordict will automatically be populated with empty tensors of batch-size N. After that, updates will be written in-place. Note that this would also work with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
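
A sketch of the shuffled variant, reusing foo() from above:

data = TensorDict({}, batch_size=[N])
for i in torch.randperm(N).tolist():
    data[i] = foo()  # the first write, whatever its index, allocates the full storage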

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) the parent batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = data.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = data["key 2", "sub-key"]
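
Conversely, a flattened tensordict can be re-nested, so the two representations round-trip (a quick sketch):

>>> data_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (data_nested["key 2", "sub-key"] == data["key 2", "sub-key"]).all()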

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict artifacts such as shape operations (e.g. reshape or permutation), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
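
For instance, a tensorclass can hold another tensorclass as a field (a minimal sketch; the Batch class and its field names are made up for illustration):

>>> @tensorclass
... class Batch:
...     data: MyData
...     weight: torch.Tensor
...
>>> batch = Batch(data=data, weight=torch.ones(100), batch_size=[100])
>>> assert batch[:10].data.image.shape == torch.Size([10, 3, 64, 64])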

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and above, as well as PyTorch 1.12 and above.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that there may be bc-breaking changes introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
