
📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away per-tensor boilerplate: an operation applied to a TensorDict is dispatched to each of its leaves for you, which makes your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfers to device, fast node-to-node communication through consolidate, and compatibility with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)

Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers will be made asynchronously and synchronizations will be called whenever needed.

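import torch
from tensordict import TensorDict

# Example tensors (illustrative keys and shapes)
dict_of_tensor = {"obs": torch.randn(10, 3), "reward": torch.randn(10, 1)}

# Fast and safe asynchronous copy to 'cuda' (requires a CUDA device)
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)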

Coding an optimizer

For instance, with TensorDict you can write the Adam optimizer as you would for a single torch.Tensor and apply it directly to a TensorDict of weights. On CUDA, these operations rely on fused kernels, which makes them fast to execute:

from tensordict import TensorDict


class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates both start at zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        grad = self.weights.grad
        # Update the running moments in place
        self._mu.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-corrected estimates are computed out of place so the running moments are preserved
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # mu and sigma are temporaries, so in-place ops on them are safe here
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))
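For reference, `step` is intended to implement the standard Adam update rule, applied elementwise to every leaf. In the notation below, $g_t$ is the gradient (`weights.grad`), $\theta$ the weights, and $\mu$ (`_mu`) and $\sigma$ (`_sigma`) the running first and second moments, which the standard algorithm starts from zero:

$$
\begin{aligned}
\mu_t &= \beta_1 \mu_{t-1} + (1 - \beta_1)\, g_t \\
\sigma_t &= \beta_2 \sigma_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{\mu}_t &= \frac{\mu_t}{1 - \beta_1^t}, \qquad \hat{\sigma}_t = \frac{\sigma_t}{1 - \beta_2^t} \\
\theta_t &= \theta_{t-1} - \alpha \, \frac{\hat{\mu}_t}{\sqrt{\hat{\sigma}_t} + \epsilon}
\end{aligned}
$$

Only $\mu_t$ and $\sigma_t$ are carried over between steps. The bias-corrected $\hat{\mu}_t$ and $\hat{\sigma}_t$ are recomputed at every step.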

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can reuse a training loop across highly heterogeneous tasks. Each step of the training loop (data collection and transforms, model prediction, loss computation, etc.) can be tailored to the use case at hand without affecting the others. For instance, the loop above can be used for classification, segmentation and many other tasks.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is in beta, meaning that backward-compatibility-breaking (BC-breaking) changes may be introduced, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.

Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

tensordict_nightly-2024.9.16-cp312-cp312-win_amd64.whl (340.6 kB)

Uploaded CPython 3.12 Windows x86-64

tensordict_nightly-2024.9.16-cp311-cp311-win_amd64.whl (339.9 kB)

Uploaded CPython 3.11 Windows x86-64

tensordict_nightly-2024.9.16-cp310-cp310-win_amd64.whl (339.0 kB)

Uploaded CPython 3.10 Windows x86-64

tensordict_nightly-2024.9.16-cp39-cp39-win_amd64.whl (338.9 kB)

Uploaded CPython 3.9 Windows x86-64

File details

Details for the file tensordict_nightly-2024.9.16-cp312-cp312-win_amd64.whl.

File hashes

Hashes for tensordict_nightly-2024.9.16-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 8e8af5491d27ad81a2d9f28943adf3dff82cefdcb0a744471432cabbac7d48d0
MD5 6fb2b6dd794d373c5640b56d22d47bc6
BLAKE2b-256 81f6c0c0508a87df39ecf8a1b179c6c15b19709210b2444e927e2999070f4236

See more details on using hashes here.

File details

Details for the file tensordict_nightly-2024.9.16-cp312-cp312-manylinux1_x86_64.whl.

File hashes

Hashes for tensordict_nightly-2024.9.16-cp312-cp312-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 8cf401b73bcbd79fb754318cbdc2e8d173bdecddfc4e83dceced9690c50a3991
MD5 3cb902427bfc56111a3350099949edaf
BLAKE2b-256 8cea14ac134a67285c4c10a1b003a2d6619044cfd7dc5e5394957edacafac4f9

File details

Details for the file tensordict_nightly-2024.9.16-cp311-cp311-win_amd64.whl.

File hashes

Hashes for tensordict_nightly-2024.9.16-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 7a7436ee6acadad3d0bb26ef13e0f6b6148ada415500a94a58480e3dc67face5
MD5 85be6b46b7ba9895ac660e6c1149f3a7
BLAKE2b-256 ec8e382d303b47239de34ac4518c19efd84bb8e5acd54a098fac7487ed4ff0f8

File details

Details for the file tensordict_nightly-2024.9.16-cp311-cp311-manylinux1_x86_64.whl.

File hashes

Hashes for tensordict_nightly-2024.9.16-cp311-cp311-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 71c39c7b2bd3a6e2a3eb9f49b76e3c3bf9d269a29bb3236276fcffc5d224c01c
MD5 5399bb4cbc4118a3c7893d056ec9d6cb
BLAKE2b-256 bbd872ccb037d4848aa33e728f6adbfaf90fae03aac733637e3d6024892192c4

File details

Details for the file tensordict_nightly-2024.9.16-cp310-cp310-win_amd64.whl.

File hashes

Hashes for tensordict_nightly-2024.9.16-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 5f949611d20bb28a12b3ca345eec88f3f93b6cd6f121aab30cdaa93e61a1d5aa
MD5 a0c93885fe5ed6017661d670fda6613a
BLAKE2b-256 e603161144ace74518ca44ec7a3e61aed5f1f0911f3475c250b1cce90f9e5e2f

File details

Details for the file tensordict_nightly-2024.9.16-cp310-cp310-manylinux1_x86_64.whl.

File hashes

Hashes for tensordict_nightly-2024.9.16-cp310-cp310-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 eae878a854214cc1d4b1740f7892f4bd620c208822e48c8c22b017a1afbd7e35
MD5 863029b6ac08bbbcf484884c606a9d5f
BLAKE2b-256 513a3a561c4904bd65249f37432b23019ea5281e8d42383657bc11fc47452c94

File details

Details for the file tensordict_nightly-2024.9.16-cp39-cp39-win_amd64.whl.

File hashes

Hashes for tensordict_nightly-2024.9.16-cp39-cp39-win_amd64.whl
Algorithm Hash digest
SHA256 c95ab92895737816b29b12fd9f807393a7f84a5baedf63e5bd7983fdb672a017
MD5 71b1e4809960ee837ad5eb442663c136
BLAKE2b-256 c7bace607e1aa09cf91eb1fb4d60711cb24e9b74ab6e46376a4e781c043064d5

File details

Details for the file tensordict_nightly-2024.9.16-cp39-cp39-manylinux1_x86_64.whl.

File hashes

Hashes for tensordict_nightly-2024.9.16-cp39-cp39-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 e59a995bf1c36d8b06dda60f21df0013d0baa610f3a412613b1641e69c4a8792
MD5 c891e13f7aa7aa536989420b44bdf52d
BLAKE2b-256 6f406e5301e89cd2b304a2d5cbef105d74ea0ed31df646f90602358cf84f5809
