📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away hand-tailored per-tensor operations, and makes your code less error-prone by dispatching each operation over the leaves for you.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)


Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers will be made asynchronously and synchronizations will be called whenever needed.

# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)

Coding an optimizer

For instance, using TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input. On CUDA, these operations rely on fused kernels, making them very fast to execute:

from tensordict import TensorDict


class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the tensordict for efficiency: its structure is cached
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, initialized to zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        self.t += 1
        # Update the biased moment estimates in place
        self._mu.mul_(self.beta1).add_(self.weights.grad * (1 - self.beta1))
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2) * (1 - self.beta2))
        # Bias-corrected estimates, computed out of place so the
        # running statistics are preserved across steps
        mu = self._mu / (1 - self.beta1 ** self.t)
        sigma = self._sigma / (1 - self.beta2 ** self.t)
        self.weights.data.add_(mu / (sigma.sqrt() + self.eps) * (-self.alpha))

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transforms, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and upward, as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a deprecation warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.


Source Distributions

No source distribution files are available for this release.

Built Distributions

  • tensordict_nightly-2024.9.14-cp312-cp312-win_amd64.whl (340.4 kB): CPython 3.12, Windows x86-64
  • tensordict_nightly-2024.9.14-cp311-cp311-win_amd64.whl (339.7 kB): CPython 3.11, Windows x86-64
  • tensordict_nightly-2024.9.14-cp310-cp310-win_amd64.whl (339.0 kB): CPython 3.10, Windows x86-64
  • tensordict_nightly-2024.9.14-cp39-cp39-win_amd64.whl (338.8 kB): CPython 3.9, Windows x86-64

