📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored operations and dispatches each operation to the leaves for you, which makes your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)

Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers to and from a device to make them safe and fast. By default, data transfers are asynchronous, and synchronization is triggered whenever it is needed.

# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)

Coding an optimizer

For instance, using TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input as well. On CUDA, these operations rely on fused kernels, making them very fast to execute:

from tensordict import TensorDict


class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, both initialized to zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the running moment estimates in place
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-corrected copies (out-of-place, so the running estimates are preserved)
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # Parameter update: w <- w - alpha * mu / (sqrt(sigma) + eps)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be used across classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, run:

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To get the latest features, install the nightly build:

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is in beta, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
