
📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored operations, making your code less error-prone: it takes care of dispatching each operation to the leaf tensors for you.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • ⏱️ Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)

tensordict.png
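The dispatch-to-leaves idea behind the composability bullet can be illustrated with a minimal pure-Python sketch. Plain nested dicts and floats stand in for TensorDict and tensors here; `map_leaves` is an illustrative name, not tensordict API:

```python
# Hypothetical sketch of the "dispatch to leaves" idea behind TensorDict,
# using plain nested dicts and floats instead of tensordicts and tensors.

def map_leaves(fn, data):
    """Apply fn to every leaf of a nested dict, keeping the structure."""
    if isinstance(data, dict):
        return {key: map_leaves(fn, value) for key, value in data.items()}
    return fn(data)

nested = {"obs": {"pixels": 1.0, "state": 2.0}, "reward": 3.0}
doubled = map_leaves(lambda x: x * 2, nested)
# doubled == {"obs": {"pixels": 2.0, "state": 4.0}, "reward": 6.0}
```

TensorDict provides this behavior natively: arithmetic operations, or TensorDict.apply, act on every leaf tensor while preserving the nested structure.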

Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers to and from device to make them safe and fast. By default, data transfers are asynchronous, and synchronizations are triggered whenever needed.

# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)

Coding an optimizer

For instance, using TensorDict you can write the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input. On CUDA, these operations rely on fused kernels, making them very fast to execute:

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the tensordict for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates start at zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the biased first and second moment estimates
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias correction, out of place so the running estimates are preserved
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # Parameter update
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)), alpha=-self.alpha)
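As a sanity check on the update rule, here is a hypothetical scalar re-derivation of a single Adam step using plain Python floats (same default hyperparameters as the optimizer above, with moment estimates starting at zero):

```python
# One Adam step on a single scalar parameter, spelled out by hand.
beta1, beta2, alpha, eps = 0.9, 0.999, 1e-3, 1e-6

w, grad = 1.0, 0.5
mu = sigma = 0.0
t = 0

# biased moment estimates
mu = beta1 * mu + (1 - beta1) * grad
sigma = beta2 * sigma + (1 - beta2) * grad ** 2
t += 1
# bias correction: at t=1 this recovers grad and grad**2 exactly
mu_hat = mu / (1 - beta1 ** t)
sigma_hat = sigma / (1 - beta2 ** t)
# parameter update
w = w - alpha * mu_hat / (sigma_hat ** 0.5 + eps)
# w is now approximately 1.0 - alpha, i.e. about 0.999
```

Because the moments start at zero, the first bias-corrected step reduces to a signed, unit-magnitude step scaled by alpha, which is a quick way to verify the implementation above behaves like Adam.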

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
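To make the task-agnostic claim concrete, here is a minimal pure-Python sketch of the same loop shape. `DummyModel` and `DummyLoss` are hypothetical stand-ins, not torch or tensordict objects; the loop body only assumes the model reads from and writes into a dict-like object:

```python
# Hypothetical illustration: the loop is generic because only the
# components know the task; the loop body never inspects the data.

class DummyModel:
    def __call__(self, data):
        data["prediction"] = data["input"] * 2  # write back into the dict
        return data

class DummyLoss:
    def __call__(self, data):
        return abs(data["prediction"] - data["target"])

def train(dataset, model, loss_module):
    losses = []
    for data in dataset:
        data = model(data)               # model reads and writes the dict
        losses.append(loss_module(data))
        # loss.backward() / optimizer.step() / zero_grad() would go here
    return losses

dataset = [{"input": 1.0, "target": 2.0}, {"input": 3.0, "target": 5.0}]
losses = train(dataset, DummyModel(), DummyLoss())
# losses == [0.0, 1.0]
```

Swapping the task then means swapping `DummyModel` and `DummyLoss` for real modules, while `train` itself is untouched.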

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, and PyTorch 1.12 and later.

To enjoy the latest features, one can instead install the nightly build:

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a deprecation warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.



