TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming | TensorDict for parameter serialization | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can be easily used across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the first dimensions of each of the tensors. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> data["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert data["key 3"].device is torch.device("cuda:0")

But that is not all: you can also store nested values in a tensordict:

>>> data["nested", "key"] = torch.zeros(3, 4) # the batch-size must match

and any nested tuple structure will be unravelled to make it easy to read code and write ops programmatically:

>>> data["nested", ("supernested", ("key",))] = torch.zeros(3, 4) # the batch-size must match
>>> assert (data["nested", "supernested", "key"] == 0).all()
>>> assert (("nested",), "supernested", (("key",),)) in data.keys(include_nested=True)  # this works too!

You can also store non-tensor data in tensordicts:

>>> data = TensorDict({"a-tensor": torch.randn(1, 2)}, batch_size=[1, 2])
>>> data["non-tensor"] = "a string!"
>>> assert data["non-tensor"] == "a string!"

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = data[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating several tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(data.view(-1).shape)
torch.Size([12])
>>> print(data.reshape(-1).shape)
torch.Size([12])
>>> print(data.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them, and so on.
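
A few of these operations in a minimal sketch (device casting is shown towards CPU so it runs without a GPU; "cuda:0" works the same way when one is available):

>>> data = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
>>> data_cpu = data.to("cpu")              # device casting
>>> shared = data.clone().share_memory_()  # a copy living in shared memory
>>> chunks = data.split(2, dim=0)          # tensordicts with batch sizes [2, 4] and [1, 4]
>>> rows = data.unbind(0)                  # 3 tensordicts with batch size [4]
>>> bigger = data.expand(2, 3, 4)          # batch size [2, 3, 4]
>>> assert bigger.batch_size == torch.Size([2, 3, 4])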

If a functionality is missing, it is easy to implement it using apply() or apply_():

tensordict_uniform = data.apply(lambda tensor: tensor.uniform_())

apply() is also a great way to filter a tensordict, for instance:

data = TensorDict({"a": torch.tensor(1.0, dtype=torch.float), "b": torch.tensor(1, dtype=torch.int64)}, [])
data_float = data.apply(lambda x: x if x.dtype == torch.float else None) # contains only the "a" key
assert "b" not in data_float

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.
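
For instance, when every node mounts the same folder, one worker can memory-map a tensordict to that folder and the other workers can lazily map the same files back without any copy. A minimal sketch, where "/shared/scratch/td" is a placeholder for any commonly mounted path:

>>> # on the writer node
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> data.memmap_("/shared/scratch/td")
>>> # on the reader nodes
>>> data_read = TensorDict.load_memmap("/shared/scratch/td")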

TensorDict for functional programming

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> params = TensorDict.from_module(model)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> with params.to_module(model):
...     out = model(x)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> def func(x, params):
...     with params.to_module(model):
...         return model(x)
>>> y = vmap(func, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and (soon) torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

TensorDict for parameter serialization and building datasets

TensorDict offers an API for parameter serialization that can be >3x faster than regular calls to torch.save(state_dict). Moreover, because tensors will be saved independently on disk, you can deserialize your checkpoint on an arbitrary slice of the model.

>>> model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 3))
>>> params = TensorDict.from_module(model)
>>> params.memmap("/path/to/saved/folder/", num_threads=16)  # adjust num_threads for speed
>>> # load params
>>> params = TensorDict.load_memmap("/path/to/saved/folder/", num_threads=16)
>>> params.to_module(model)  # load onto model
>>> params["0"].to_module(model[0])  # load on a slice of the model
>>> # in the latter case we could also have loaded only the slice we needed
>>> params0 = TensorDict.load_memmap("/path/to/saved/folder/0", num_threads=16)
>>> params0.to_module(model[0])  # load on a slice of the model

The same functionality can be used to access data in a dataset stored on disk. Storing the dataset as a single contiguous tensor on disk, accessed through the tensordict.MemoryMappedTensor primitive, and reading slices of it is not only much faster than loading single files one at a time, but also easier and safer (because there is no pickling or third-party library involved):

# allocate memory of the dataset on disk
data = TensorDict({
    "images": torch.zeros((128, 128, 3), dtype=torch.uint8),
    "labels": torch.zeros((), dtype=torch.int)}, batch_size=[])
data = data.expand(1000000)
data = data.memmap_like("/path/to/dataset")
# ==> Fill your dataset here
# Let's get 3 items of our dataset:
data[torch.tensor([1, 10000, 500000])]  # This is much faster than loading the 3 images independently

Preprocessing with TensorDict.map

Preprocessing huge contiguous (or not!) datasets can be done via TensorDict.map which will dispatch a task to various workers:

from tensordict import TensorDict

def process_data(data):
    images = data.get("images").flip(-2).clone()
    labels = data.get("labels") // 10
    # we update the td in-place
    data.set_("images", images)  # flip images
    data.set_("labels", labels)  # cluster labels
    return data

if __name__ == "__main__":
    data = TensorDict.load_memmap("/path/to/dataset")  # the memory-mapped dataset built above
    data_preproc = data.map(process_data, num_workers=4, chunksize=0, pbar=True)  # process one item at a time

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    data = TensorDict({}, batch_size=[])
    data["a"] = torch.randn(3)
    data["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return data

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

data = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

data = TensorDict({}, batch_size=[N])
for i in range(N):
    data[i] = foo()

which also results in a tensordict (here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i==0, your empty tensordict will automatically be populated with empty tensors of batch-size N. After that, updates will be written in-place. Note that this would also work with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).
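
For example, here is a minimal sketch that fills the tensordict in a random order, reusing foo() from above:

N = 10
data = TensorDict({}, batch_size=[N])
for i in torch.randperm(N):
    data[int(i)] = foo()  # the first write pre-allocates; later writes fill in-place
assert data["a"].shape == torch.Size([N, 3])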

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its batch size must match (but could be longer than) the parent's batch size.

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = data.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = data["key 2", "sub-key"]

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with a more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.
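
For instance, a tensorclass field can itself hold a tensorclass (or a TensorDict), and indexing propagates through the nesting. A minimal sketch reusing the MyData instance built above:

>>> @tensorclass
... class Sample:
...     inner: MyData          # nested tensorclass
...     weight: torch.Tensor
...
>>> sample = Sample(inner=data, weight=torch.rand(100), batch_size=[100])
>>> print(sample[:10].inner.image.shape)
torch.Size([10, 3, 64, 64])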

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may still be introduced, but they should come with advance warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
