
📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts tailored per-tensor operations away, making your code less error-prone by dispatching each operation over the leaves for you.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)


Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers will be made asynchronously and synchronizations will be called whenever needed.

import torch
from tensordict import TensorDict

# Illustrative payload; any dict of tensors works
dict_of_tensor = {"a": torch.randn(3), "b": torch.randn(3)}

# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)

Coding an optimizer

For instance, using TensorDict you can write the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict of parameters. On CUDA, these operations rely on fused kernels, making execution very fast:

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the tensordict for efficiency (its set of keys can no longer change)
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, initialized to zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the biased moment estimates (`alpha=` is the scalar multiplier)
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-corrected estimates; out-of-place div keeps the running statistics intact
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later and PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is in beta, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
