
📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored per-tensor operations, making your code less error-prone by dispatching each operation to the leaves for you.

The key features are listed below (a short sketch illustrating several of them follows the list):

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)
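
As a minimal sketch of several of these features (the keys and shapes below are made up for illustration), a TensorDict is built from a possibly nested dictionary and manipulated much like a single tensor:

import torch
from tensordict import TensorDict

# a batch of 32 elements whose entries, including nested ones, share the leading dimension
data = TensorDict({
    "observation": torch.zeros(32, 3, 84, 84),
    "next": {"reward": torch.zeros(32, 1)},   # nesting
}, batch_size=[32])

# tensor-like shape operations are dispatched to every leaf
sub = data[:4]                         # indexing / slicing
doubled = torch.cat([data, data], 0)   # concatenation
# data = data.to("cuda")               # device cast, if a GPU is available
# data.memmap_("path/to/folder")       # memory-mapped serialization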


Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers to and from devices so that they are both safe and fast. By default, data transfers are asynchronous and synchronizations are triggered whenever they are needed.

from tensordict import TensorDict

# dict_of_tensor is a (possibly nested) dictionary of CPU tensors
# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)
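
The feature list above also mentions consolidate. As a hedged sketch (the keys and shapes are illustrative), consolidating a TensorDict packs all of its leaves into a single contiguous storage, so that serializing it or sending it to another node costs one transfer instead of one per leaf:

import torch
from tensordict import TensorDict

td = TensorDict({
    "a": torch.randn(1000),
    "nested": {"b": torch.randn(1000)},
}, batch_size=[])
td_flat = td.consolidate()   # all leaves now share one contiguous buffer
# td_flat behaves like td; pickling or sending it requires a single copy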

Coding an optimizer

For instance, using TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input. On CUDA, these operations rely on fused kernels, making execution very fast:

from tensordict import TensorDict

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the tensordict for efficiency (and to guard against accidental key changes)
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, initialized to zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the biased first and second moment estimates
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias correction on copies, so the running estimates are left untouched
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # In-place parameter update
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))
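
As a hedged usage sketch (the model and shapes below are made up), the parameters of a regular nn.Module can be extracted as a TensorDict with TensorDict.from_module and handed to the optimizer above:

import torch
from torch import nn
from tensordict import TensorDict

model = nn.Sequential(nn.Linear(3, 64), nn.Tanh(), nn.Linear(64, 1))
weights = TensorDict.from_module(model)   # parameters gathered in a nested TensorDict
optim = Adam(weights)

loss = model(torch.randn(16, 3)).pow(2).mean()
loss.backward()
optim.step()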

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the example above can easily be reused across classification and segmentation tasks, among many others.
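
As a hedged sketch of what a model that "reads and writes tensordicts" can look like (the keys, shapes and modules below are illustrative), tensordict.nn provides wrappers that map regular nn.Modules onto TensorDict entries:

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

# wrap plain modules so they read from and write to tensordict keys
model = TensorDictSequential(
    TensorDictModule(nn.Linear(10, 64), in_keys=["input"], out_keys=["hidden"]),
    TensorDictModule(nn.Linear(64, 1), in_keys=["hidden"], out_keys=["prediction"]),
)

data = TensorDict({
    "input": torch.randn(32, 10),
    "target": torch.randn(32, 1),
}, batch_size=[32])

data = model(data)   # fills in "hidden" and "prediction"
loss = (data["prediction"] - data["target"]).pow(2).mean()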

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may still be introduced, but they should come with a warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.


Download files

Download the file for your platform.

Source Distributions

No source distribution files are available for this release.

Built Distributions

tensordict_nightly-2024.9.29-cp312-cp312-win_amd64.whl (348.4 kB): CPython 3.12, Windows x86-64
tensordict_nightly-2024.9.29-cp311-cp311-win_amd64.whl (347.8 kB): CPython 3.11, Windows x86-64
tensordict_nightly-2024.9.29-cp310-cp310-win_amd64.whl (346.8 kB): CPython 3.10, Windows x86-64
tensordict_nightly-2024.9.29-cp39-cp39-win_amd64.whl (346.8 kB): CPython 3.9, Windows x86-64

File hashes

tensordict_nightly-2024.9.29-cp312-cp312-win_amd64.whl
  SHA256: 97ecc4d42290b11313334f1ad90a2f71674a79f802eb0203cb816989a7c73208
  MD5: be8ced7ac3699096f0cc1b175676fd38
  BLAKE2b-256: 226aa29575a6679447e3ef028a86d8b5fa75ab8d3b50a5c6cc3731cb13d5fc49

tensordict_nightly-2024.9.29-cp312-cp312-manylinux1_x86_64.whl
  SHA256: 2236720124e65bb8be0e2522bd38532e6a6079792b0a6ab862d186488567fdd8
  MD5: 5e5b943fa3f0b39eb1280c77ed91327d
  BLAKE2b-256: 3b378d8ad8382b582ff5160bb080d08a6caa7c214aff9e2c9b3e414111726353

tensordict_nightly-2024.9.29-cp311-cp311-win_amd64.whl
  SHA256: ca55471c307ca250c68dea36a663c508dd19a13bad5bb89ce514858845259bbd
  MD5: 0f0ecb16f73ed34cb1a9826e314cd5bc
  BLAKE2b-256: 35dfb35476ca9757c7126b7e54bd2c0b023e8ee7fb1a661cd110eaf79630fd0d

tensordict_nightly-2024.9.29-cp311-cp311-manylinux1_x86_64.whl
  SHA256: e4e79614858262e2f108387ecc51783c743295b6a35b4691f6924f01128d6780
  MD5: 1f0c80af6d15c5f9e1f7202d96cdb66a
  BLAKE2b-256: 01b3c5d61861e9b45640d2b693a268eea192f57beefb3e5ccad4d9813b9eea0d

tensordict_nightly-2024.9.29-cp310-cp310-win_amd64.whl
  SHA256: 6da2347ff4dc898e0e865b449fec18508f2d9c8dc4c804e0ef387a810a42b3b4
  MD5: ad405feb08fbbbacb873741af4a441ba
  BLAKE2b-256: f6fbf337c388620ae6b66600df78699f636ac2116175bf90df95d548fc5d2f25

tensordict_nightly-2024.9.29-cp310-cp310-manylinux1_x86_64.whl
  SHA256: 8edc5b4892465925aea1ac995c37bee587e43ec29e30ce861d66f33492a736a4
  MD5: 6ea7b3c99b9710130a0c7be30c84b90e
  BLAKE2b-256: 024c6bc41465f51eec24d768043e508ebf91d8c6661dfabc57ee839b824e4e46

tensordict_nightly-2024.9.29-cp39-cp39-win_amd64.whl
  SHA256: 6848dfed082343e3cdaccf010f61f5205f518c349215d968466763ff9707a79d
  MD5: 62df97bc3d74d34b6218caffc76e36ad
  BLAKE2b-256: 0c3f422282c0fd92f8c9d67148346d012a4a876d911245fce5f9d701d7c770a4

tensordict_nightly-2024.9.29-cp39-cp39-manylinux1_x86_64.whl
  SHA256: 6085179215241d620ffa4dd8afb57b40997ebb1c2cd397ce4aab6e2de6a027b8
  MD5: ad425871d8217bb7cb7270139e1c3e47
  BLAKE2b-256: f492c4a70eca21c4472a422a91e9bc06eefe38c0f9a1e0390795f0c2dd2439c3
