📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It removes the need for hand-written per-tensor loops: operations called on a TensorDict are dispatched to every leaf tensor for you, which makes your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate(), compatibility with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)


Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers to and from devices to make them both safe and fast. By default, data is transferred asynchronously, and synchronization is performed automatically whenever it is needed.

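import torch
from tensordict import TensorDict

# Hypothetical example entries: any mapping of string keys to tensors
dict_of_tensor = {"obs": torch.randn(3, 4), "reward": torch.zeros(3, 1)}
# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)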

Coding an optimizer

With TensorDict, you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input. On CUDA, these operations rely on fused kernels, which makes them fast to execute:

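from tensordict import TensorDict

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock for efficiency: the set of keys can no longer change
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates both start at zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        grad = self.weights.grad
        # Update the biased moment estimates in place; each op runs on every leaf
        self._mu.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias correction out of place, so the running estimates are not overwritten
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # weights <- weights - alpha * mu / (sqrt(sigma) + eps)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))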

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

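# dataset yields TensorDicts; model, loss_module and optimizer are defined elsewhere
for i, data in enumerate(dataset):
    # the model reads its inputs from the tensordict and writes its outputs back into it
    data = model(data)
    # the loss module reads the entries it needs (e.g. predictions and targets)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()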

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the loop (data collection and transforms, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the example above can easily be reused for classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, run:

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To get the latest features, install the nightly build:

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking (BC-breaking) changes may be introduced, but they should come with a deprecation warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.


Download files


Source Distributions

No source distribution files are available for this release.

Built Distributions

  • tensordict_nightly-2024.8.21-cp312-cp312-win_amd64.whl (331.0 kB): CPython 3.12, Windows x86-64
  • tensordict_nightly-2024.8.21-cp311-cp311-win_amd64.whl (330.7 kB): CPython 3.11, Windows x86-64
  • tensordict_nightly-2024.8.21-cp310-cp310-win_amd64.whl (330.1 kB): CPython 3.10, Windows x86-64
  • tensordict_nightly-2024.8.21-cp39-cp39-win_amd64.whl (329.8 kB): CPython 3.9, Windows x86-64
  • tensordict_nightly-2024.8.21-cp38-cp38-win_amd64.whl (330.2 kB): CPython 3.8, Windows x86-64

File hashes

tensordict_nightly-2024.8.21-cp312-cp312-win_amd64.whl
  SHA256: e769ff94d1a8d7b5398237ae167269a7aa85f3da6bff88076b98d996d91ddc75
  MD5: 847060c23fc2a07b5c18259b08d5c6e1
  BLAKE2b-256: 6baa692cd3b5849dd80a7417a2d6492c541d903a1b8ca0f2ff0db0f93c8772e1

tensordict_nightly-2024.8.21-cp312-cp312-manylinux1_x86_64.whl
  SHA256: 4f234df1ce8319de4a51d1079285ffb07046aef081981f0fbb5491a85067497b
  MD5: 1817ffd317bd74169e2e9d888c9d3ad6
  BLAKE2b-256: f11e46ba47ce52a09733472e665df8d732977bc2b3b8c55979f1acf3421aaf8b

tensordict_nightly-2024.8.21-cp311-cp311-win_amd64.whl
  SHA256: 0f4ae433543dda4764545164161463992fe979711f027f2945a1add4fb4f4802
  MD5: 58cd497cd3dc9ee8a09e0afe1b23d605
  BLAKE2b-256: 6fd0cfd328680769fc184daf18a397a9cc564edea41d7e6b5c35ce651a820873

tensordict_nightly-2024.8.21-cp311-cp311-manylinux1_x86_64.whl
  SHA256: 7e7e17b99432ec8ed293fa397e873926c9ded518681c374d5ab6390fc9644376
  MD5: 89554ab6f8f9ff0c1f35d11e36315483
  BLAKE2b-256: e350068f12ce26d0d64428de0fd02ca5131a2e3f36d9806fe2428691c1f52203

tensordict_nightly-2024.8.21-cp310-cp310-win_amd64.whl
  SHA256: d1c427113a694e4d33b7d0793a0a0ac6b1356e0335e1524c6a512984f8698b6e
  MD5: 7088f6f33c381baf9ec3292f8e2224b7
  BLAKE2b-256: a883d043e26ab40272ef8ef5484de16662fdf7bce9cbbac6b86fc56f6dd35a2b

tensordict_nightly-2024.8.21-cp310-cp310-manylinux1_x86_64.whl
  SHA256: 656c5d1a31c3825f6c73073702337f9e06bc56386993e827720969682a2a3c39
  MD5: 04a3a6af1ddcf0f3e096b0a7b5493804
  BLAKE2b-256: 96a62a9c32302784cd4840939a25061c9b1245d4f8cc57b86fb62e0f31741f39

tensordict_nightly-2024.8.21-cp39-cp39-win_amd64.whl
  SHA256: 80d64b514b44fad717c64493405eff2c34eb42b9bd458643d000a73a3bc58eb8
  MD5: f6e08fe53f32c6c96f417b181333a5eb
  BLAKE2b-256: 051d72815dd2529ad76626c661a0501a519aeceac0070b531d51bda4b841a6b8

tensordict_nightly-2024.8.21-cp39-cp39-manylinux1_x86_64.whl
  SHA256: 5887de658949a55c59cdaa4210c0d661caa11df1648891dd41208815bfba8c4a
  MD5: 040a5ca0553294121350c481702b6bf9
  BLAKE2b-256: f15fb66b56ba9354a2ab2c3cc514e7985f2eddb91368a4da0e98fd24aef83215

tensordict_nightly-2024.8.21-cp38-cp38-win_amd64.whl
  SHA256: 7a41d27b1f31fd18e55fe77e2fe8de0400e175ff8562abe9b21b908322e332cd
  MD5: 7c7a52af1e449c3f1e8f5bb836d8db28
  BLAKE2b-256: 273904fa0c5b7093881eb54c5d26ad470e3b2e154ec83bf8ffc3a6a7211d0a00

tensordict_nightly-2024.8.21-cp38-cp38-manylinux1_x86_64.whl
  SHA256: b1998121fab56903572a06027bba90d7245c42112d9036a49da676ff9484c29e
  MD5: 37908bf929e876b08d1cdf33970cf48e
  BLAKE2b-256: a190cf945bb3fba5ebf13ddf2b11a6fe5026348630484fefdfa38006df9d10f4
