📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored operations and makes your code less error-prone, since it takes care of dispatching each operation to the leaves for you.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate(), and compatibility with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)


Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers will be made asynchronously and synchronizations will be called whenever needed.

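import torch
from tensordict import TensorDict

# Example input; any mapping from names to tensors works here
dict_of_tensor = {"obs": torch.randn(4, 3), "reward": torch.zeros(4, 1)}

# Fast and safe asynchronous copy to 'cuda' (requires a CUDA device)
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)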

Coding an optimizer

For instance, with TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input. On CUDA, these operations rely on fused kernels, making them very fast to execute:

from tensordict import TensorDict

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, both initialized to zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        grad = self.weights.grad
        # Update the biased moment estimates in place
        self._mu.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-corrected estimates, computed out of place so the running state is preserved
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # w <- w - alpha * mu / (sqrt(sigma) + eps), applied to every leaf at once
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transforms, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be used across classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking (BC-breaking) changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.

Built Distributions

tensordict_nightly-2024.8.15-cp312-cp312-win_amd64.whl (331.0 kB): CPython 3.12, Windows x86-64
tensordict_nightly-2024.8.15-cp311-cp311-win_amd64.whl (330.7 kB): CPython 3.11, Windows x86-64
tensordict_nightly-2024.8.15-cp310-cp310-win_amd64.whl (330.1 kB): CPython 3.10, Windows x86-64
tensordict_nightly-2024.8.15-cp39-cp39-win_amd64.whl (329.8 kB): CPython 3.9, Windows x86-64
tensordict_nightly-2024.8.15-cp38-cp38-win_amd64.whl (330.2 kB): CPython 3.8, Windows x86-64
