📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebase more readable, compact, modular and fast. It abstracts away the tailored per-tensor operations you would otherwise write by hand, making your code less error-prone: each operation is dispatched onto the leaves for you.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfers to device, fast node-to-node communication through consolidate, and compatibility with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)

[figure: tensordict.png]
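
A minimal sketch of what this looks like in practice (the keys, shapes and file path below are illustrative, not taken from the documentation):

import torch
from tensordict import TensorDict

# Build a nested TensorDict; the batch size is shared by every leaf
data = TensorDict(
    {
        "observation": torch.randn(4, 3),
        "nested": {"reward": torch.zeros(4, 1)},
    },
    batch_size=[4],
)

# Tensor-like operations are dispatched to every leaf
row = data[0]                        # index the shared batch dimension
half = data[:2]                      # slice, like a tensor
stacked = torch.stack([data, data])  # batch_size becomes [2, 4]

# Serialization / memory-mapping of the whole structure
data.memmap_("/tmp/td")              # one memory-mapped file per leaf, plus metadata
reloaded = TensorDict.load_memmap("/tmp/td")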

Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers to and from device to make them both safe and fast. By default, data transfers are made asynchronously, and synchronizations are issued whenever they are needed.

from tensordict import TensorDict

# dict_of_tensor: a plain dict mapping string keys to torch.Tensor values
# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)
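
The key features above also mention consolidate for fast node-to-node communication. Here is a hedged sketch of how it is meant to be used, assuming consolidate() takes no required arguments and returns a copy whose leaves share a single contiguous storage (check the API reference for the exact signature):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(10), "b": torch.zeros(10, 5)}, batch_size=[10])
# A single contiguous buffer instead of many small ones: cheaper to serialize and send
td_flat = td.consolidate()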

Coding an optimizer

Using TensorDict, you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input as well. On CUDA, these operations rely on fused kernels, making them very fast to execute:

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the tensordict for efficiency (no key creation/deletion afterwards)
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # Running first and second moments, stored as TensorDicts as well
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # In-place update of the moment estimates, dispatched to every leaf
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-corrected copies (out of place, so the running stats are left intact)
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))
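
A minimal usage sketch (not part of the original example): the parameters of a module are gathered into a TensorDict with TensorDict.from_module, and the module, shapes and loss below are only illustrative:

import torch
from tensordict import TensorDict

module = torch.nn.Linear(3, 4)
# Gather the module's parameters into a TensorDict (keys "weight" and "bias")
params = TensorDict.from_module(module)

optim = Adam(params, alpha=1e-3)

loss = module(torch.randn(8, 3)).pow(2).mean()
loss.backward()
optim.step()  # updates the Linear layer's parameters in place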

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
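
As an illustration of how the pieces can fit together, here is a hedged sketch of a classification setup. TensorDictModule comes from tensordict.nn, but the keys, model, dataset and loss below are made up for the example:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# A model that reads "image" from the tensordict and writes "logits" back into it
model = TensorDictModule(
    torch.nn.Linear(784, 10), in_keys=["image"], out_keys=["logits"]
)

def loss_module(data: TensorDict) -> torch.Tensor:
    # Reads a key written by the dataset ("label") and one written by the model ("logits")
    return torch.nn.functional.cross_entropy(data["logits"], data["label"])

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# A stand-in dataset: each element is a tensordict holding "image" and "label"
dataset = [
    TensorDict(
        {"image": torch.randn(32, 784), "label": torch.randint(10, (32,))},
        batch_size=[32],
    )
    for _ in range(10)
]

for i, data in enumerate(dataset):
    data = model(data)          # writes "logits" into the tensordict
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()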

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To get the latest features, install the nightly build:

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may still be introduced, but they should be preceded by a deprecation warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.


