📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular, and fast. It abstracts away tailored operations and dispatches them across the leaf tensors for you, making your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)


Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers will be made asynchronously and synchronizations will be called whenever needed.

from tensordict import TensorDict

# dict_of_tensor maps names to CPU tensors (assumed defined elsewhere)
# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)

Coding an optimizer

For instance, using TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input as well. On CUDA, these operations rely on fused kernels, making them very fast to execute:

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the tensordict for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, stored as TensorDicts
        self._mu = weights.data.clone()
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the moment estimates on every leaf tensor at once
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias correction must not mutate the running estimates in place
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7+ and PyTorch 1.12+.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that bc-breaking changes may be introduced, but they should come with a deprecation warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.


Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

tensordict_nightly-2024.9.11-cp312-cp312-win_amd64.whl (339.2 kB): CPython 3.12, Windows x86-64
tensordict_nightly-2024.9.11-cp311-cp311-win_amd64.whl (338.5 kB): CPython 3.11, Windows x86-64
tensordict_nightly-2024.9.11-cp310-cp310-win_amd64.whl (337.9 kB): CPython 3.10, Windows x86-64
tensordict_nightly-2024.9.11-cp39-cp39-win_amd64.whl (337.6 kB): CPython 3.9, Windows x86-64

File hashes

tensordict_nightly-2024.9.11-cp312-cp312-win_amd64.whl
  SHA256: af0bdd6a856e7b48a35836d91ce8cdc074826f3989d1fbdccd7e4d423214a19a
  MD5: 533e7788aa506a3a3892d47ecb5fa1c6
  BLAKE2b-256: 6749125ba93fbd913ff93448f86b8507c32cfa4a71d7064d75a7498a2ddf75e1

tensordict_nightly-2024.9.11-cp312-cp312-manylinux1_x86_64.whl
  SHA256: 915ffe3e5e0b49c1deb31b6a3e50f453a9263d2f8da8b5422a30c8f31601410b
  MD5: 0f0d63a0455a0fcdeaad324f82f11c7f
  BLAKE2b-256: 7729c9c1fbf656de9e31656e56810e41cf222ec3b556013f24fce4b2ea0276dc

tensordict_nightly-2024.9.11-cp311-cp311-win_amd64.whl
  SHA256: 3c679b56e2c2661d0f12e727bed17835a29ac78b375a6cfd96700f8c784d9f93
  MD5: d1327f6e17a2524b57f9594916e02825
  BLAKE2b-256: c455430ba2afec655365913a5199ce6fe22be23823ed1252c3c03eaa57e4ff90

tensordict_nightly-2024.9.11-cp311-cp311-manylinux1_x86_64.whl
  SHA256: 554c2db90b9fea7f83e36b16bd18188336bef948044e92fd875bb53a3a680c1a
  MD5: 44e29ebd4e5036964b1329d4c53325c3
  BLAKE2b-256: a0078d8ebdb4141b8e721a3f64d82852edaa1022db17d234dd8ebc92d854bb01

tensordict_nightly-2024.9.11-cp310-cp310-win_amd64.whl
  SHA256: b71469310e24a3cf365d81a7ac80208c579ad19182416d7ae32cd4b6cfe1375c
  MD5: a484b916d9051569e91c3e6f4138e4c8
  BLAKE2b-256: fb528612ca66c8d460779b1939e3b6d815d69b864c452fbb74d81bb1d0143513

tensordict_nightly-2024.9.11-cp310-cp310-manylinux1_x86_64.whl
  SHA256: 7a6155303b0935408981a9a109ad9974fd97643068ef968d4cb5de41d5c01cf8
  MD5: c830ca16dea74fa8269112ee9ac32ede
  BLAKE2b-256: 05a69e16c2f6c8dd6509d75c8429f599b8e6742ed38a4f772524c2c8cdf58d70

tensordict_nightly-2024.9.11-cp39-cp39-win_amd64.whl
  SHA256: 5f2165069b145697ea5bc39df7bccd5c81dca97739c4070521e4409c9dedbd3b
  MD5: 19e99d8be4cc9be29093ddae5f6168fd
  BLAKE2b-256: 5da98552942621fb1ed834115117038a5cb1fc0e47e5ccab0cb42cc865465e82

tensordict_nightly-2024.9.11-cp39-cp39-manylinux1_x86_64.whl
  SHA256: 18c0e1226c8b32abe3125d8b7128f5a450a921810ec167b5bb93fbcb596a35f9
  MD5: f871c761bf25f08de30ef1b5bd1ceb5e
  BLAKE2b-256: 029bd6c0e664db71d86897ff5bc966cd7418df567e7696915f0f96e9e076e0e2
