
Project description


📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored operations, making your code less error-prone by taking care of dispatching each operation onto the leaves for you.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation (a short sketch follows this list).
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)
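
A minimal sketch of composability, shape operations, nesting and tuple-key access (the entry names and shapes below are made up for illustration):

import torch
from tensordict import TensorDict

# A TensorDict with a shared leading batch dimension and a nested entry
td = TensorDict(
    {
        "obs": torch.randn(4, 3, 32, 32),
        "action": torch.randn(4, 6),
        "next": {"obs": torch.randn(4, 3, 32, 32)},
    },
    batch_size=[4],
)

# Tensor-like shape operations are dispatched to every leaf
sub = td[:2]                     # indexing/slicing along the batch dimension
stacked = torch.stack([td, td])  # batch_size becomes [2, 4]
reshaped = td.reshape(2, 2)      # batch_size becomes [2, 2]

# Nested values are addressed with tuple keys
next_obs = td["next", "obs"]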


Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers to and from device to make them both safe and fast. By default, data transfers are asynchronous, and synchronizations are issued whenever needed.

import torch
from tensordict import TensorDict

# An example payload (any dictionary of tensors works)
dict_of_tensor = {"a": torch.randn(3), "b": torch.zeros(3)}

# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)

Coding an optimizer

For instance, using TensorDict you can write the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input. On CUDA, these operations rely on fused kernels, making them very fast to execute:

from tensordict import TensorDict


class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the TensorDict for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, stored as TensorDicts of zeros
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the biased moment estimates (alpha is keyword-only in add_)
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-correct on fresh copies so the running estimates stay intact
        mu = self._mu / (1 - self.beta1 ** self.t)
        sigma = self._sigma / (1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))
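
As a usage sketch for the class above (the toy linear module, random input and squared-output loss are illustrative assumptions, not part of the library):

import torch
from tensordict import TensorDict

# Gather the module's parameters into a TensorDict; the leaves are the
# actual Parameters, so updating their .data updates the module in place.
module = torch.nn.Linear(8, 1)
weights = TensorDict.from_module(module)

optim = Adam(weights)
loss = module(torch.randn(16, 8)).pow(2).mean()
loss.backward()
optim.step()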

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for data in dataset:
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
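
As one possible way to wire the loop above, tensordict.nn provides TensorDictModule and TensorDictSequential, which let plain nn.Modules read from and write to TensorDict keys. The sketch below is illustrative: the key names, layer sizes and per-sample cross-entropy loss are assumptions, not a prescribed recipe.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

# Wrap plain nn.Modules so they read from / write to TensorDict keys
model = TensorDictSequential(
    TensorDictModule(torch.nn.Linear(16, 32), in_keys=["obs"], out_keys=["hidden"]),
    TensorDictModule(torch.nn.Linear(32, 10), in_keys=["hidden"], out_keys=["logits"]),
)
loss_module = TensorDictModule(
    torch.nn.CrossEntropyLoss(reduction="none"),
    in_keys=["logits", "label"],
    out_keys=["loss"],
)

data = TensorDict(
    {"obs": torch.randn(8, 16), "label": torch.randint(10, (8,))},
    batch_size=[8],
)
data = model(data)        # writes "hidden" and "logits" into the tensordict
data = loss_module(data)  # writes a per-sample "loss" entry
loss = data["loss"].mean()
loss.backward()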

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and above, as well as PyTorch 1.12 and above.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking (bc-breaking) changes may be introduced, but they should come with a deprecation warning. Hopefully, these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files are available for this release. See the tutorial on generating distribution archives.

Built Distributions

tensordict_nightly-2024.10.19-cp312-cp312-win_amd64.whl (349.5 kB)

Uploaded: CPython 3.12, Windows x86-64

tensordict_nightly-2024.10.19-cp311-cp311-win_amd64.whl (348.9 kB)

Uploaded: CPython 3.11, Windows x86-64

tensordict_nightly-2024.10.19-cp310-cp310-win_amd64.whl (347.9 kB)

Uploaded: CPython 3.10, Windows x86-64

tensordict_nightly-2024.10.19-cp39-cp39-win_amd64.whl (347.9 kB)

Uploaded: CPython 3.9, Windows x86-64

File details

Details for the file tensordict_nightly-2024.10.19-cp312-cp312-win_amd64.whl.

File metadata

File hashes

Hashes for tensordict_nightly-2024.10.19-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 e2dc47aef4b98bea2b129434615fe2ede20bf4fd2f3acda2fc12031710c48036
MD5 7ebf2792a1df18c376293473a878b829
BLAKE2b-256 a496f8f743e1eb0c841d3475fb6e357bbf4af27aa6ac465a8a04565bcf501c34


File details

Details for the file tensordict_nightly-2024.10.19-cp312-cp312-manylinux1_x86_64.whl.

File metadata

File hashes

Hashes for tensordict_nightly-2024.10.19-cp312-cp312-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 736feeae6b905d211255b4597e573d617056b11f4d2d071433a6721426df965f
MD5 efc3025833da2cd6a86a4b507bc02567
BLAKE2b-256 9ae976c56947a6c9cff521937af09c8e901ae449fa28d2bfb37ccd8012e8cdff


File details

Details for the file tensordict_nightly-2024.10.19-cp311-cp311-win_amd64.whl.

File metadata

File hashes

Hashes for tensordict_nightly-2024.10.19-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 81dcb90cf44f7c5f4b41def6011f99cd44ec7e8f7148b01b9635d99d58cb7cc4
MD5 7c09aa59b5cb39c5e3408e30e10d6f4d
BLAKE2b-256 da8e10cd3729c3ad31b991fe903fd26ac4617993511a109d030c396a6a8d059d


File details

Details for the file tensordict_nightly-2024.10.19-cp311-cp311-manylinux1_x86_64.whl.

File metadata

File hashes

Hashes for tensordict_nightly-2024.10.19-cp311-cp311-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 77baacd6ed67a26051b0f21a032e1cd4964404af2def7b2098d78e01d53a18ec
MD5 a018428280a9cde48925e8b2cdcde243
BLAKE2b-256 a4262759d124b4043d9a65f441e06e4202d6d5edf56827042ed660378f93132a


File details

Details for the file tensordict_nightly-2024.10.19-cp310-cp310-win_amd64.whl.

File metadata

File hashes

Hashes for tensordict_nightly-2024.10.19-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 f0f0a1482dd5b0b892e65a5c45ffb2ec39f94b109357685df9f7ef97d85a8554
MD5 e9947dc594a8053bcc66add9014e862b
BLAKE2b-256 da85bc30d31069f4fd6eb5ee6e0c1686098ddb29c8fdd464afd2102c968836f6


File details

Details for the file tensordict_nightly-2024.10.19-cp310-cp310-manylinux1_x86_64.whl.

File metadata

File hashes

Hashes for tensordict_nightly-2024.10.19-cp310-cp310-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 d51df5682ea83f0ccbcb16d5e49e4c17b1e109eee08236b0e6b58bf5e0a05106
MD5 1d85022a1ba4d607b9bfe36eb887df30
BLAKE2b-256 ce94681c42672e6358d32724e3f947988716b003f166202fcf645fc814bfec86


File details

Details for the file tensordict_nightly-2024.10.19-cp39-cp39-win_amd64.whl.

File metadata

File hashes

Hashes for tensordict_nightly-2024.10.19-cp39-cp39-win_amd64.whl
Algorithm Hash digest
SHA256 654d9e8049937d169ea3802ec27bcd9c5e245b1cdac715d1e6ce2ad5a3fc2ae5
MD5 2f1aa508b551fa00fc6c05827e2ab86a
BLAKE2b-256 376a0c15567eac10454f5892c627393ecf751c3c65d815b04b84a117db6d3101


File details

Details for the file tensordict_nightly-2024.10.19-cp39-cp39-manylinux1_x86_64.whl.

File metadata

File hashes

Hashes for tensordict_nightly-2024.10.19-cp39-cp39-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 da3b4c2498385b9e2e4b93b1061e231782f2ca062371509991b1a3f3b5198bc1
MD5 9f2bf702defea9e24376fee95e446118
BLAKE2b-256 aa3d9e19e6069582c7daf8d0177fbaef2fe5ddcbf35b413fc2fb94db8c90c83a

