📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features

TensorDict makes your code bases more readable, compact, modular and fast. It abstracts away tailored operations and takes care of dispatching each operation to the leaves for you, which makes your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate(), compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)

Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers will be made asynchronously and synchronizations will be called whenever needed.

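import torch
from tensordict import TensorDict

dict_of_tensor = {"obs": torch.randn(3, 4), "reward": torch.zeros(3, 1)}
# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(dict_of_tensor, batch_size=[3], device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)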

Coding an optimizer

For instance, using TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input. On CUDA, these operations rely on fused kernels, which makes them very fast to execute:

import torch
from tensordict import TensorDict

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates both start at zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        grad = self.weights.grad
        # Update the running (biased) moment estimates in place
        self._mu.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-corrected copies: out-of-place so the running estimates are kept intact
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

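for data in dataset:
    # dataset yields TensorDict batches; the model reads its inputs from
    # the tensordict and writes its predictions back into it
    data = model(data)
    # the loss module reads predictions and targets and returns a scalar loss
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()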

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be used for classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is in beta, meaning that backward-compatibility-breaking (BC-breaking) changes may be introduced, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.

Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

tensordict_nightly-2024.8.17-cp312-cp312-win_amd64.whl (331.0 kB)

Uploaded CPython 3.12 Windows x86-64

tensordict_nightly-2024.8.17-cp311-cp311-win_amd64.whl (330.7 kB)

Uploaded CPython 3.11 Windows x86-64

tensordict_nightly-2024.8.17-cp310-cp310-win_amd64.whl (330.1 kB)

Uploaded CPython 3.10 Windows x86-64

tensordict_nightly-2024.8.17-cp39-cp39-win_amd64.whl (329.8 kB)

Uploaded CPython 3.9 Windows x86-64

tensordict_nightly-2024.8.17-cp38-cp38-win_amd64.whl (330.2 kB)

Uploaded CPython 3.8 Windows x86-64

File details

Details for the file tensordict_nightly-2024.8.17-cp312-cp312-win_amd64.whl.

File hashes

Hashes for tensordict_nightly-2024.8.17-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 5a5b16c085c9daafa8fd31ef69c4d5f0a398bb998e91733727a67425635b59f2
MD5 a45edd45bf049d71809e65aefa659c71
BLAKE2b-256 79e83847485b07bad917fa959b6b4ecb4ce24d0dd4e01819c56b238f2242a4f5

File details

Details for the file tensordict_nightly-2024.8.17-cp312-cp312-manylinux1_x86_64.whl.

File hashes

Hashes for tensordict_nightly-2024.8.17-cp312-cp312-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 6cbaeaff7af0ac99b9e4c3e1839510873eb20481b7e678468f86f4124540d603
MD5 5c29653a63eb2883fc720562e47aa14e
BLAKE2b-256 234b4e025b9019e9085706e421e21c391f228a1345af8f7348a5a09ac7e303c4

File details

Details for the file tensordict_nightly-2024.8.17-cp311-cp311-win_amd64.whl.

File hashes

Hashes for tensordict_nightly-2024.8.17-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 7c49d4e7fe1af433ad6c9fa3c6f22f426b2fd9ad915a7644bb6ed084aa04d3a1
MD5 b51b1022407bec65b806b6621563eaf7
BLAKE2b-256 c46208232528d39921211804ea538f3b897af6df93523cface51d701a988b418

File details

Details for the file tensordict_nightly-2024.8.17-cp311-cp311-manylinux1_x86_64.whl.

File hashes

Hashes for tensordict_nightly-2024.8.17-cp311-cp311-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 61abc24f9bd52eac490c65f44ed3948784b23830f9373fcd963788f30c62ffce
MD5 3da5196bad76f5c34d2018ad82526be3
BLAKE2b-256 29d85d541d816eb835b8dff48f185c0f92705baae52a3678c2c0e84ce03342d7

File details

Details for the file tensordict_nightly-2024.8.17-cp310-cp310-win_amd64.whl.

File hashes

Hashes for tensordict_nightly-2024.8.17-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 afe7698f8948439079a1f776dbc27043ae4ddbd341b25f932ed5ad1f2cb4e0ba
MD5 8eb98dc8e0f901957f5d5d8a1db383a8
BLAKE2b-256 97fb2efc20b2b531daa03f7a1445e16a1a3b035d2e1cb877eb5571d1905ea897

File details

Details for the file tensordict_nightly-2024.8.17-cp310-cp310-manylinux1_x86_64.whl.

File hashes

Hashes for tensordict_nightly-2024.8.17-cp310-cp310-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 f7c8753e4877fae751a6f5fe989448453682abea9017d89e093af88cbff0473e
MD5 b000ade860bce54e675a17eafae8accb
BLAKE2b-256 84f785c9af232c303c1168e1b5a144bc8783937583668dce4357372c65b17c98

File details

Details for the file tensordict_nightly-2024.8.17-cp39-cp39-win_amd64.whl.

File hashes

Hashes for tensordict_nightly-2024.8.17-cp39-cp39-win_amd64.whl
Algorithm Hash digest
SHA256 2501b334d575e7185472cd469a91b1225f44be0d456ce97c203f64c9f2516b4f
MD5 704d52ddb63506bd6b30065eb2890bb7
BLAKE2b-256 aee801712591850ef157ce71ece42ec8e8dc11b5cfe6d33d7317877d32cd7786

File details

Details for the file tensordict_nightly-2024.8.17-cp39-cp39-manylinux1_x86_64.whl.

File hashes

Hashes for tensordict_nightly-2024.8.17-cp39-cp39-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 2690ca3c007b90669d7f15f23eb608773d94d6d08db5a5ea76de9d681984e4ec
MD5 e987459455affa1645fe0514d5b07ad9
BLAKE2b-256 ed0cf2cee77f072f0a83a792d8d5e6207cab6f5a0ce1dad4b610d1a308a8a53d

File details

Details for the file tensordict_nightly-2024.8.17-cp38-cp38-win_amd64.whl.

File hashes

Hashes for tensordict_nightly-2024.8.17-cp38-cp38-win_amd64.whl
Algorithm Hash digest
SHA256 efd2f4cbe90f5b8d021e448f1cdd9fc7809819a731e81c02286c3ff0cf9b1d3d
MD5 177caffdafc4eabb96dd32a1656be993
BLAKE2b-256 49a9732482ef9c741228bde80e9cf14ca96c8c30c079c2a137f5cd43d6f09bda

File details

Details for the file tensordict_nightly-2024.8.17-cp38-cp38-manylinux1_x86_64.whl.

File hashes

Hashes for tensordict_nightly-2024.8.17-cp38-cp38-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 a78255c2fb0b527012ed3bcc0fa0659de91554841eb8eedc54e7858094c3a490
MD5 36d9bdad70c5a2f3de6c4ce8f14edb84
BLAKE2b-256 81dba1b72441b9152f17dcef80c7e6224162e64d54dbdc038ea0b23bebb46e9c
