
📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.


Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored operations and makes your code less error-prone, as it takes care of dispatching each operation to the leaves for you.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)


Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers will be made asynchronously and synchronizations will be called whenever needed.

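import torch
from tensordict import TensorDict

# Illustrative input: a plain dict of tensors
dict_of_tensor = {"obs": torch.randn(4, 3), "reward": torch.zeros(4, 1)}
# Fast and safe asynchronous copy to 'cuda' (requires a CUDA device)
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy back to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force a synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)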

Coding an optimizer

Using TensorDict, you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input. On CUDA, these operations rely on fused kernels, which makes them fast to execute:

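from tensordict import TensorDict


class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock for efficiency: no keys can be added or removed afterwards
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, both initialized to zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        grad = self.weights.grad
        # Update the running (biased) moment estimates in place
        self._mu.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias correction is out-of-place so the running estimates are not overwritten
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # w <- w - alpha * mu / (sqrt(sigma) + eps), applied to every leaf at once
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))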
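For reference, the standard Adam update (Kingma & Ba, 2015) that this class is meant to implement, with g_t the gradient at step t, \alpha the step size and \epsilon the numerical-stability constant:

```latex
\begin{aligned}
\mu_t &= \beta_1 \mu_{t-1} + (1 - \beta_1)\, g_t, \qquad \mu_0 = 0,\\
\sigma_t &= \beta_2 \sigma_{t-1} + (1 - \beta_2)\, g_t^2, \qquad \sigma_0 = 0,\\
\hat{\mu}_t &= \frac{\mu_t}{1 - \beta_1^t}, \qquad \hat{\sigma}_t = \frac{\sigma_t}{1 - \beta_2^t},\\
w_t &= w_{t-1} - \alpha\, \frac{\hat{\mu}_t}{\sqrt{\hat{\sigma}_t} + \epsilon}.
\end{aligned}
```

The bias-corrected quantities \hat{\mu}_t and \hat{\sigma}_t are temporaries: only \mu_t and \sigma_t are carried over to the next step, so the correction must not be written back into the running estimates.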

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for data in dataset:
    # the model reads its inputs from the tensordict and writes its outputs into it
    data = model(data)
    # the loss module reads predictions and targets from the same tensordict
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transforms, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is in beta, meaning that backward-compatibility-breaking (BC-breaking) changes may be introduced, though they should be announced with a deprecation warning. Such changes should be rare, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
