📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored operations and dispatches each operation over the leaves for you, making your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)
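
The leaf-dispatch idea behind these features can be pictured in plain Python. The snippet below is a toy stand-in, not tensordict's implementation: it applies a function to every leaf of a nested dict while preserving the structure, which is what TensorDict does for you with tensors.

```python
def map_leaves(fn, tree):
    """Apply fn to every leaf of a nested dict, keeping the structure."""
    if isinstance(tree, dict):
        return {k: map_leaves(fn, v) for k, v in tree.items()}
    return fn(tree)

# A nested "tensordict" of plain lists standing in for tensors
nested = {"obs": [1.0, 2.0], "state": {"pos": [0.5], "vel": [0.1]}}
doubled = map_leaves(lambda leaf: [x * 2 for x in leaf], nested)
```

With a real TensorDict, operations such as `td * 2` or `td.to("cuda")` perform exactly this kind of recursive dispatch for you.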


Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers to and from device to make them both safe and fast. By default, data transfers are asynchronous, and synchronizations are performed whenever they are needed.

import torch
from tensordict import TensorDict

dict_of_tensor = {"a": torch.randn(3), "b": torch.zeros(3)}  # any mapping of tensors

# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)
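
The pattern tensordict automates here, launching the transfer asynchronously and synchronizing only when the result is consumed, can be pictured with a plain-Python analogy. Threads stand in for CUDA streams below; this is an illustration of the pattern, not tensordict's implementation:

```python
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=1)

def copy_async(src):
    # Start the "transfer" in the background; the caller is not blocked
    return pool.submit(list, src)

future = copy_async([1, 2, 3])  # returns a handle immediately
result = future.result()        # synchronize only when the data is needed
```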

Coding an optimizer

For instance, using TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input as well. On CUDA, these operations rely on fused kernels, making them very fast to execute:

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the tensordict for efficiency
        self.weights = weights.lock_()
        self.t = 0

        self._mu = weights.data.clone()      # first-moment running estimate
        self._sigma = weights.data.mul(0.0)  # second-moment running estimate
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-corrected estimates, computed out of place so the
        # running statistics are left untouched
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))
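
The update rule above can be sanity-checked on a single scalar parameter in plain Python. This is a toy re-derivation of the same arithmetic, independent of TensorDict:

```python
def adam_step(w, g, mu, sigma, t, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-6):
    """One Adam update for a single scalar weight w with gradient g."""
    mu = beta1 * mu + (1 - beta1) * g            # first-moment estimate
    sigma = beta2 * sigma + (1 - beta2) * g * g  # second-moment estimate
    t += 1
    mu_hat = mu / (1 - beta1 ** t)               # bias correction
    sigma_hat = sigma / (1 - beta2 ** t)
    w -= alpha * mu_hat / (sigma_hat ** 0.5 + eps)
    return w, mu, sigma, t
```

With TensorDict, the same arithmetic is applied to every leaf of the weights container at once.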

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
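
The contract this loop relies on can be mocked with plain dicts standing in for tensordicts. `Model`, `loss_module` and the data layout below are illustrative assumptions, not the library's API: the only requirement is that the model reads its inputs from the container and writes its outputs back into it.

```python
class Model:
    """Reads inputs from the container and writes predictions back into it."""
    def __call__(self, data):
        data["pred"] = [x * 0.5 for x in data["obs"]]
        return data

def loss_module(data):
    # Mean-squared error between predictions and targets
    return sum((p - t) ** 2 for p, t in zip(data["pred"], data["target"]))

batch = {"obs": [2.0, 4.0], "target": [1.0, 2.0]}
batch = Model()(batch)
loss = loss_module(batch)
```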

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7+ and PyTorch 1.12+.

To enjoy the latest features, one can instead install the nightly build:

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may still be introduced, but they should come with a deprecation warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.


Built Distributions

tensordict_nightly-2024.9.8-cp312-cp312-win_amd64.whl (332.4 kB): CPython 3.12, Windows x86-64
tensordict_nightly-2024.9.8-cp311-cp311-win_amd64.whl (331.7 kB): CPython 3.11, Windows x86-64
tensordict_nightly-2024.9.8-cp310-cp310-win_amd64.whl (331.1 kB): CPython 3.10, Windows x86-64
tensordict_nightly-2024.9.8-cp39-cp39-win_amd64.whl (330.8 kB): CPython 3.9, Windows x86-64

File hashes

tensordict_nightly-2024.9.8-cp312-cp312-win_amd64.whl
SHA256: ba456741968ebb8e1cfa90dce767fb8b2ecf907d8f097be1fa9e01e5d903e03c
MD5: c451aadc12beda9f3a75fdce7c04a1a4
BLAKE2b-256: d33489851955aee5e79cb6ded0dc8e47b0ff6f2104922f426fac96724cda0de6

tensordict_nightly-2024.9.8-cp312-cp312-manylinux1_x86_64.whl
SHA256: 7b664714b0d019c49ad08ff7a5a525b3eb97cab66d33f61c610d10a64d27e371
MD5: 157b688944b5589010f40b4843b728cc
BLAKE2b-256: d8149d5feebc052f6c255135a17532fc6ea721f54855dd0ba2fea3f69a3a956b

tensordict_nightly-2024.9.8-cp311-cp311-win_amd64.whl
SHA256: 94d78904bc280e5b1714f39b5ea2d5bd47018cda8b92c896c5c5ad97b5452b7b
MD5: b1c771a7f6f5c4587df9204ff13bf8b8
BLAKE2b-256: cef3a313ba35c3ce0fe3c1f5ab0063e33e3876d604c94686c4d55e47eb4434bc

tensordict_nightly-2024.9.8-cp311-cp311-manylinux1_x86_64.whl
SHA256: 82fcfdfa25b3dd450eab06f02aa41ed5b49a41166b7b84867e87c8c61f314f03
MD5: 1f7ce0585eb584eb6ec11f5227eb4917
BLAKE2b-256: d05ba0a6da4949e6d5b1cb5910bdb1280e33182e080896bdc851a6314c037c98

tensordict_nightly-2024.9.8-cp310-cp310-win_amd64.whl
SHA256: 9edd79bc3b866a58e9a4790550c6561cd371a5e9f1aeca5051b459a52407a815
MD5: d5fe4471e2a77192047faf1954ec2af4
BLAKE2b-256: 7ad270c362c8c3e8e13f462da03f9b173ae8c72251fa16977e9cbf316a6f686d

tensordict_nightly-2024.9.8-cp310-cp310-manylinux1_x86_64.whl
SHA256: e873cb0e9a8ff957a180ccc0166833921ca97cf48975dd71ac7ce24f25994d1f
MD5: 439520d2f6211375509c453ae238c2c8
BLAKE2b-256: 12d8ab7485a1a3d81064570c37ed0e3c2b2ba046a499fca92e3da7ba21124e34

tensordict_nightly-2024.9.8-cp39-cp39-win_amd64.whl
SHA256: c3cccaeb12d77d98ddae76855f99c9efd2cf1a4cf6fbd91a656de7b33d4be079
MD5: 476745d1f70db91868693b58c58512cd
BLAKE2b-256: ef771a2e39233560049febace78e6c928fbe66ba4f133bf0f6786e008333b5dc

tensordict_nightly-2024.9.8-cp39-cp39-manylinux1_x86_64.whl
SHA256: fb4fe20468fe9aeff381cdc1149553af80584eff8427bb3bd808b5a61b0443a1
MD5: 1f3b36559a6ac02d1183fb0fffe960ef
BLAKE2b-256: 9e77078f1270d95f8a276f08dcffbe1f11e9554bb545d5f0e861932567fd2d9a
