
TensorDict is a PyTorch-dedicated tensor container.

Project description


TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, and casting to a device or storage. The codebase consists of two main components: TensorDict, a specialized dictionary for PyTorch tensors, and tensorclass, a dataclass for tensors.

import torch
from tensordict import TensorDict

data = TensorDict(
    obs=torch.randn(128, 84),
    action=torch.randn(128, 4),
    reward=torch.randn(128, 1),
    batch_size=[128],
)

data_gpu = data.to("cuda")      # all tensors move together
sub = data_gpu[:64]              # all tensors are sliced
stacked = torch.stack([data, data])  # works like a tensor

Key Features | Examples | Installation | Ecosystem | Citation | License

Key Features

TensorDict makes your code-bases more readable, compact, modular and fast. It abstracts away tailored operations, dispatching them on the leaves for you.

  • Composability: TensorDict generalizes torch.Tensor operations to collections of tensors. [tutorial]
  • Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile. [tutorial]
  • Shape operations: indexing, slicing, concatenation, reshaping -- everything you can do with a tensor. [tutorial]
  • Distributed / multiprocessed: distribute TensorDict instances across workers, devices and machines. [doc]
  • Serialization and memory-mapping for efficient checkpointing. [doc]
  • Functional programming and compatibility with torch.vmap. [tutorial]
  • Nesting: nest TensorDict instances to create hierarchical structures. [tutorial]
  • Lazy preallocation: preallocate memory without initializing tensors. [tutorial]
  • @tensorclass: a specialized dataclass for torch.Tensor. [tutorial]

Examples

Check our Getting Started guide for a full overview of TensorDict's features.

Before / after

Working with groups of tensors is common in ML. Without a shared structure, every operation must be repeated for each tensor:

# Without TensorDict
obs = obs.to("cuda")
action = action.to("cuda")
reward = reward.to("cuda")
next_obs = next_obs.to("cuda")

obs_batch = obs[:32]
action_batch = action[:32]
reward_batch = reward[:32]
next_obs_batch = next_obs[:32]

With TensorDict, all of that collapses to:

# With TensorDict
data = data.to("cuda")
data_batch = data[:32]

This holds for any operation: reshape, unsqueeze, permute, to, indexing, torch.stack, torch.cat, and many more.

Generic training loops

Using TensorDict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Each step of the training loop -- data loading, model prediction, loss computation -- can be swapped independently without touching the rest. The same loop works across classification, segmentation, RL, and more.

Fast copy on device

By default, device transfers are asynchronous and synchronized only when needed:

td_cuda = TensorDict(**dict_of_tensors, device="cuda")
td_cpu = td_cuda.to("cpu")
td_cpu = td_cuda.to("cpu", non_blocking=False)  # force synchronous

Coding an optimizer

Using TensorDict, you can write the Adam optimizer as you would for a single tensor and apply it to a collection of parameters. On CUDA, these operations use fused kernels:

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # lock the tensordict so its set of keys cannot change under us
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        self._mu = weights.data.mul(0.0)     # first moment, initialized at zero
        self._sigma = weights.data.mul(0.0)  # second moment, initialized at zero
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # bias-correct out of place so the running averages stay intact
        mu = self._mu / (1 - self.beta1 ** self.t)
        sigma = self._sigma / (1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)), alpha=-self.alpha)

Ecosystem

TensorDict is used across a range of domains:

  • Reinforcement Learning: TorchRL (PyTorch), DreamerV3-torch, Dreamer4, SkyRL
  • LLM Post-Training: verl, ROLL (Alibaba), LMFlow, LoongFlow (Baidu)
  • Robotics & Simulation: MuJoCo Playground (Google DeepMind), ProtoMotions (NVIDIA), holosoma (Amazon)
  • Physics & Scientific ML: PhysicsNeMo (NVIDIA)
  • Genomics: Medaka (Oxford Nanopore)

Installation

With pip:

pip install tensordict

For the latest features:

pip install tensordict-nightly

With conda:

conda install -c conda-forge tensordict

With uv + PyTorch nightlies:

If you're using a PyTorch nightly, install tensordict with --no-deps to prevent uv from re-resolving torch from PyPI:

uv pip install -e . --no-deps

Or explicitly point uv at the PyTorch nightly wheel index:

uv pip install -e . --prerelease=allow -f "https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html"

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

License

TensorDict is licensed under the MIT License. See LICENSE for details.

Project details



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files are available for this release. See the tutorial on generating distribution archives.

Built Distributions

If you're not sure about the file name format, learn more about wheel file names.

tensordict_nightly-2026.4.22-cp314-cp314-win_amd64.whl (592.3 kB; CPython 3.14, Windows x86-64)
tensordict_nightly-2026.4.22-cp313-cp313-win_amd64.whl (590.6 kB; CPython 3.13, Windows x86-64)
tensordict_nightly-2026.4.22-cp312-cp312-win_amd64.whl (590.6 kB; CPython 3.12, Windows x86-64)
tensordict_nightly-2026.4.22-cp311-cp311-win_amd64.whl (589.4 kB; CPython 3.11, Windows x86-64)
tensordict_nightly-2026.4.22-cp310-cp310-win_amd64.whl (587.4 kB; CPython 3.10, Windows x86-64)

File hashes

tensordict_nightly-2026.4.22-cp314-cp314-win_amd64.whl
  SHA256      ecd5f2d04c7ba3edbe3216a7dc7f02269b398dcda1ca63f1773c9ca2b2dd9b97
  MD5         d443008ec69bfcc6140303428cbda7be
  BLAKE2b-256 33577816c1038821c62cb662d44b0bded0cd9377cdcb24eda90470c650c4e0be

tensordict_nightly-2026.4.22-cp313-cp313-win_amd64.whl
  SHA256      576a552e8657c24af880381d37d9f88ec1f0915ed8c2cea2a51f52ffac1fc0bd
  MD5         634454f1f85586e409ecff6e23d257d4
  BLAKE2b-256 69e2bf3daf8656b648a4bb119cd4a6a11518f9a35743a8b592dae9eceec9bc0c

tensordict_nightly-2026.4.22-cp312-cp312-win_amd64.whl
  SHA256      e77abe8ab60d91dd4ffce77cdad593c6c63ac4cd39b4842691b55f5eeba06172
  MD5         59afd14f5660333b9500d2a9bf0db3bd
  BLAKE2b-256 06dcc598c9965bbb51d474de98b270ba6118c0a5f753ee817ac4f9af6e1bc051

tensordict_nightly-2026.4.22-cp311-cp311-win_amd64.whl
  SHA256      e46a44c0bb7dc574a27c6a7650c37ac439d78a8e0ffc9178ee889fdafee10c8d
  MD5         29b2a161cac94d74522bd707083a063b
  BLAKE2b-256 d30553d4e5668e720148e06c567c0fc2197012b7af9e58bba7eed48aab823ff2

tensordict_nightly-2026.4.22-cp310-cp310-win_amd64.whl
  SHA256      11847ddbf26fdb26e7f259c194acec55a98cc0eb39f7e63a90af472c1e4c4cbc
  MD5         b4b15931be5e5bd31bf3d9e41f748309
  BLAKE2b-256 cfeccb80ad3758424bb79b9ada852150cb9bfe1058c96c4ecffcb0c5c85bf720
