TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming | TensorDict for parameter serialization | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can be easily used across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size and the first dimensions of each of the tensors must be compliant. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it will be sent to that device:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> data["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert data["key 3"].device == torch.device("cuda:0")

But that is not all: you can also store nested values in a tensordict:

>>> data["nested", "key"] = torch.zeros(3, 4) # the batch-size must match

and any nested tuple structure will be unravelled to make it easy to read code and write ops programmatically:

>>> data["nested", ("supernested", ("key",))] = torch.zeros(3, 4) # the batch-size must match
>>> assert (data["nested", "supernested", "key"] == 0).all()
>>> assert (("nested",), "supernested", (("key",),)) in data.keys(include_nested=True)  # this works too!

You can also store non-tensor data in tensordicts:

>>> data = TensorDict({"a-tensor": torch.randn(1, 2)}, batch_size=[1, 2])
>>> data["non-tensor"] = "a string!"
>>> assert data["non-tensor"] == "a string!"

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = data[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(data.view(-1).shape)
torch.Size([12])
>>> print(data.reshape(-1).shape)
torch.Size([12])
>>> print(data.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in place (or not), split them, unbind them, expand them, and so on; a few of these operations are sketched below.
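
A quick sketch of some of these operations (standard TensorDict methods; no GPU is required here):

>>> data = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> data_cpu = data.to("cpu")        # move every tensor to a device
>>> data_clone = data.clone()        # deep copy of the whole structure
>>> chunks = data.split(1, dim=0)    # tensordicts of batch_size [1, 4]
>>> items = data.unbind(0)           # tensordicts of batch_size [4]
>>> bigger = data.unsqueeze(0).expand(2, 3, 4)  # expanded batch dimensions
>>> data.share_memory_()             # place the tensors in shared memory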

If a functionality is missing, it is easy to implement it with apply() or apply_():

tensordict_uniform = data.apply(lambda tensor: tensor.uniform_())

apply() is also handy for filtering a tensordict, for instance:

data = TensorDict({"a": torch.tensor(1.0, dtype=torch.float), "b": torch.tensor(1, dtype=torch.int64)}, [])
data_float = data.apply(lambda x: x if x.dtype == torch.float else None) # contains only the "a" key
assert "b" not in data_float

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemoryMappedTensor backend can be used to seamlessly send, receive and read huge amounts of data.
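
For instance, a tensordict can be dumped to a shared filesystem on one node and lazily loaded on another; a minimal sketch, assuming /shared/scratch is a hypothetical path visible to all nodes:

>>> # on the writer node
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> data.memmap_("/shared/scratch/td")  # write the tensors as memory-mapped files
>>> # on a reader node
>>> data = TensorDict.load_memmap("/shared/scratch/td")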

TensorDict for functional programming

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> params = TensorDict.from_module(model)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> with params.to_module(model):
...     out = model(x)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> def func(x, params):
...     with params.to_module(model):
...         return model(x)
>>> y = vmap(func, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and (soon) torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

TensorDict for parameter serialization and building datasets

TensorDict offers an API for parameter serialization that can be >3x faster than regular calls to torch.save(state_dict). Moreover, because tensors will be saved independently on disk, you can deserialize your checkpoint on an arbitrary slice of the model.

>>> model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 3))
>>> params = TensorDict.from_module(model)
>>> params.memmap("/path/to/saved/folder/", num_threads=16)  # adjust num_threads for speed
>>> # load params
>>> params = TensorDict.load_memmap("/path/to/saved/folder/", num_threads=16)
>>> params.to_module(model)  # load onto model
>>> params["0"].to_module(model[0])  # load on a slice of the model
>>> # in the latter case we could also have loaded only the slice we needed
>>> params0 = TensorDict.load_memmap("/path/to/saved/folder/0", num_threads=16)
>>> params0.to_module(model[0])  # load on a slice of the model

The same functionality can be used to access data in a dataset stored on disk. Storing a single contiguous tensor on disk, accessed through the tensordict.MemoryMappedTensor primitive, and reading slices of it is not only much faster than loading single files one at a time, but also easier and safer (because there is no pickling or third-party library involved):

# allocate memory of the dataset on disk
data = TensorDict({
    "images": torch.zeros((128, 128, 3), dtype=torch.uint8),
    "labels": torch.zeros((), dtype=torch.int)}, batch_size=[])
data = data.expand(1000000)
data = data.memmap_like("/path/to/dataset")
# ==> Fill your dataset here
# Let's get 3 items of our dataset:
data[torch.tensor([1, 10000, 500000])]  # This is much faster than loading the 3 images independently

Preprocessing with TensorDict.map

Preprocessing huge contiguous (or non-contiguous!) datasets can be done via TensorDict.map, which dispatches tasks to a pool of workers:

import torch
from tensordict import TensorDict, MemoryMappedTensor
import tempfile

def process_data(data):
    images = data.get("images").flip(-2).clone()
    labels = data.get("labels") // 10
    # we update the td in place
    data.set_("images", images)  # flip image
    data.set_("labels", labels)  # cluster labels
    return data

if __name__ == "__main__":
    # `data` is the memory-mapped dataset built in the previous section
    data_preproc = data.map(process_data, num_workers=4, chunksize=0, pbar=True)  # process one item at a time

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    data = TensorDict({}, batch_size=[])
    data["a"] = torch.randn(3)
    data["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return data

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

data = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

data = TensorDict({}, batch_size=[N])
for i in range(N):
    data[i] = foo()

which also results in a tensordict like the following (here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, the empty tensordict is automatically populated with empty tensors of batch size N. After that, updates are written in place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in order), as sketched below.
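
For example, the following fills the preallocated tensordict with indices visited in random order (a small sketch with N = 10):

data = TensorDict({}, batch_size=[10])
for i in torch.randperm(10).tolist():
    data[i] = foo()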

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size should match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = data.flatten_keys(separator=".")
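
Conversely, unflatten_keys rebuilds the nested structure from the flat one (a quick sketch using the tensordict above):

>>> data_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert ("key 2", "sub-key") in data_nested.keys(include_nested=True)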

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = data["key 2", "sub-key"]

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
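
For instance, a tensorclass instance can itself be stored inside a TensorDict like any other entry, and shape operations apply to it directly; a minimal sketch building on the MyData instance above:

>>> td = TensorDict({"sample": data}, batch_size=[100])
>>> assert td["sample"].image.shape == torch.Size([100, 3, 64, 64])
>>> stacked = torch.stack([data, data], 0)  # shape ops work on tensorclasses too
>>> assert stacked.batch_size == torch.Size([2, 100])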

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
