
📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored operations and dispatches them to the leaves for you, making your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfers to device, fast node-to-node communication through consolidate, and compatibility with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping.
  • λ Functional programming and compatibility with torch.vmap.
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass).

tensordict.png

Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers are made asynchronously, and synchronizations are triggered whenever needed.

from tensordict import TensorDict

# `dict_of_tensor` is a plain dict mapping key names to tensors
# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)

Coding an optimizer

For instance, using TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input as well. On CUDA, these operations rely on fused kernels, making execution very fast:

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the tensordict for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        self._mu = weights.data.clone()
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the running first- and second-moment estimates
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-correct out of place so the running estimates are preserved
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the example above can easily be reused across classification and segmentation tasks, among many others.

Installation

With pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and above, and PyTorch 1.12 and above.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that bc-breaking changes may be introduced, but they should come with advance warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
