
📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored operations and dispatches each operation to the leaves (the individual tensors) for you, which makes your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfers to device, fast node-to-node communication through consolidate(), and compatibility with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)


Examples

This section presents a couple of standout applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers will be made asynchronously and synchronizations will be called whenever needed.

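import torch
from tensordict import TensorDict

# Hypothetical contents; any mapping from keys to tensors works here
dict_of_tensor = {"obs": torch.randn(3, 4), "reward": torch.zeros(3, 1)}
# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force a synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)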

Coding an optimizer

For instance, with TensorDict you can write the Adam optimizer as you would for a single torch.Tensor and apply it directly to a TensorDict of parameters. On CUDA, these operations rely on fused kernels, which makes them very fast to execute:

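import torch
from tensordict import TensorDict


class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock for efficiency (the set of keys can no longer change)
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, both initialized to zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the running (biased) moment estimates in place
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias correction is done out of place so the running estimates are not overwritten
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # weights <- weights - alpha * mu / (sqrt(sigma) + eps)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))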
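For reference, this is the standard Adam update written out, with every operation applied elementwise to each leaf of the TensorDict and with m_0 = v_0 = 0. Reading the class above, θ corresponds to weights, g_t to weights.grad, m_t to _mu, v_t to _sigma, and α, β₁, β₂, ε to alpha, beta1, beta2 and eps. This mapping is our interpretation of the code:

```latex
\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2\, v_{t-1} + (1 - \beta_2)\, g_t^{2} \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^{t}}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^{t}} \\
\theta_t &= \theta_{t-1} - \alpha\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
```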

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for data in dataset:
    # the model reads and writes tensordicts
    data = model(data)
    # the loss module reads the entries it needs from the same tensordict
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the example above can easily be used for classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is in beta, meaning that backward-compatibility-breaking (BC-breaking) changes may be introduced, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.


Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

tensordict_nightly-2024.8.16-cp312-cp312-win_amd64.whl (331.0 kB)

Uploaded CPython 3.12 Windows x86-64

tensordict_nightly-2024.8.16-cp311-cp311-win_amd64.whl (330.7 kB)

Uploaded CPython 3.11 Windows x86-64

tensordict_nightly-2024.8.16-cp310-cp310-win_amd64.whl (330.1 kB)

Uploaded CPython 3.10 Windows x86-64

tensordict_nightly-2024.8.16-cp39-cp39-win_amd64.whl (329.8 kB)

Uploaded CPython 3.9 Windows x86-64

tensordict_nightly-2024.8.16-cp38-cp38-win_amd64.whl (330.2 kB)

Uploaded CPython 3.8 Windows x86-64

File hashes

tensordict_nightly-2024.8.16-cp312-cp312-win_amd64.whl
SHA256 fe5c052de235d4f86dbcf06ffdf52e2638556a5875c66cd1da1614aebc6c7b84
MD5 88cd5db8ddea47f319541e6abb89946f
BLAKE2b-256 79b67de4ccec4d0fe4743ccf0b85cf2ff2f3ef62ffc767ce8b202aca7f035a75

tensordict_nightly-2024.8.16-cp312-cp312-manylinux1_x86_64.whl
SHA256 161e3f26117110d33e24d3d2507144aecab9a0edd3a1c6bca41205ca3389bccc
MD5 915f1c63029b011a0b9b0c6cbf28ae38
BLAKE2b-256 d382897078d2385523a2f2e9ff786e7369c68f9b394d5a31b2d3a0cdfd37a64c

tensordict_nightly-2024.8.16-cp311-cp311-win_amd64.whl
SHA256 c0cb41960be9ef3543725107e72a1b6cce5c89634ebcdb679511afc685231e2c
MD5 c723c51facdb3e5617e90f831a68d6bc
BLAKE2b-256 f2b79d2b53e2d1caf3df792536d4bce37812d4dcfe10aecdfb189901e672d256

tensordict_nightly-2024.8.16-cp311-cp311-manylinux1_x86_64.whl
SHA256 889c3d6ac7636e5d006bef7863ad64b413d9c39804c448d22b595008c78d49c7
MD5 ef9cdb9a4199b8aa0ad904b52e6a7954
BLAKE2b-256 7185f7847f39674b55c13168325c1b0055d398af3c3adbae0f99c322a8883869

tensordict_nightly-2024.8.16-cp310-cp310-win_amd64.whl
SHA256 a29c1f6b47c4e1f69ceafa90de793e960476ae9508997e6db24e296d1a17299a
MD5 1bb61017e2a46d88bb99afca4e9f2243
BLAKE2b-256 c29d907b434017ad2b6e5e4963e77bda895fb9802341f9e8284c8bac0ed65031

tensordict_nightly-2024.8.16-cp310-cp310-manylinux1_x86_64.whl
SHA256 8fc27563a9c530e3713b5505ea33d69b440367e1904caad44fcafe6ee7820b18
MD5 4cc7f2e14c42a4ed62a748f1c0c4e4ce
BLAKE2b-256 3cd1c5feefbf107c01c34bd424f6c4d573c8882e4984c38b090df9171eba05bb

tensordict_nightly-2024.8.16-cp39-cp39-win_amd64.whl
SHA256 0880f02160f5f269957891b3f183e3fe5fe3e093550b2bfac04d23ed58c8c23b
MD5 e3db3c7d704214f19924bae5deaa5aff
BLAKE2b-256 ae71af5c99373558fb4f71d7cc5f16e0f73c7e3187bd4959c9fddc41bcd7c455

tensordict_nightly-2024.8.16-cp39-cp39-manylinux1_x86_64.whl
SHA256 52ee506dd13cf4463fd5d199123c6ba9ff0ec8684fb880d407d615c7693c2ce9
MD5 65bc61b5718778535049b2d3898bcfbe
BLAKE2b-256 0e1e6b3ab82d7248738d20491e1ea1dc02ddfdc715f6ba23ced29b7d5e2152e2

tensordict_nightly-2024.8.16-cp38-cp38-win_amd64.whl
SHA256 7fedafd588913b41284a75540ebd1381129fc49807d3484c63ec323cf1d944c4
MD5 93ff2f2ce76b9ca5e82cbfd1ee01f479
BLAKE2b-256 fc37216f87df10c5f1d2ff1978d05087aa4c3f41d3ca9e9983022c316e527619

tensordict_nightly-2024.8.16-cp38-cp38-manylinux1_x86_64.whl
SHA256 78882664d95655af17a6b438afaf3a970ef5c32646ea1d54c86ba1f1a4da2ebd
MD5 a53e18ad4c73f8982bc6480b7f3c135c
BLAKE2b-256 5660a6cd79cc1b05414da8289aebb61177f742fd444ad2beb0a14d624d76a0dd
