📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored operations and takes care of dispatching each operation to the leaves (the individual tensors) for you, which makes your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, and compatibility with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)

Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers will be made asynchronously and synchronizations will be called whenever needed.

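import torch
from tensordict import TensorDict

# Hypothetical input: any mapping from keys to tensors sharing a leading batch dimension
dict_of_tensor = {"obs": torch.randn(3, 4), "reward": torch.randn(3, 1)}
# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(dict_of_tensor, batch_size=[3], device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)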

Coding an optimizer

For instance, with TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input as well. On CUDA, these operations rely on fused kernels, which makes them very fast to execute:

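from tensordict import TensorDict


class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates both start at zero
        self._mu = weights.data.clone().zero_()
        self._sigma = weights.data.clone().zero_()
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        grad = self.weights.grad
        # Update the biased moment estimates in place (alpha is keyword-only)
        self._mu.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias correction is out-of-place so the running moments are not overwritten
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # weights <- weights - alpha * mu / (sqrt(sigma) + eps); mu and sigma are copies here
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))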
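For reference, the Adam update that the class above is written to implement reads as follows, with the code's names mapped onto symbols (μ for _mu, σ for _sigma, w for the weights, g_t for the gradient at step t, and μ_0 = σ_0 = 0). Note that ε is added after the square root, as in the code:

```latex
\begin{aligned}
\mu_t &= \beta_1 \mu_{t-1} + (1 - \beta_1)\, g_t \\
\sigma_t &= \beta_2 \sigma_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{\mu}_t &= \frac{\mu_t}{1 - \beta_1^t}, \qquad
\hat{\sigma}_t = \frac{\sigma_t}{1 - \beta_2^t} \\
w_t &= w_{t-1} - \alpha\, \frac{\hat{\mu}_t}{\sqrt{\hat{\sigma}_t} + \epsilon}
\end{aligned}
```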

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for data in dataset:
    # the model reads and writes tensordicts
    data = model(data)
    # the loss module reads the entries it needs from the tensordict
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can reuse a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be used for classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, run:

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To get the latest features, install the nightly build:

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking (BC-breaking) changes may be introduced, but they should come with a warning. Such changes should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
