📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored operations and dispatches each operation to the leaves for you, which makes your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)

Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers will be made asynchronously and synchronizations will be called whenever needed.

# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)

Coding an optimizer

For instance, using TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input as well. On CUDA, these operations rely on fused kernels, which makes them very fast to execute:

from tensordict import TensorDict

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, both initialized to zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the running moments in place
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-corrected estimates; out-of-place div leaves the running moments intact
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # The in-place ops below only modify the temporary copies mu and sigma
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is in beta, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warranty. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
