📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features

TensorDict makes your code bases more readable, compact, modular and fast. It abstracts away hand-written, per-tensor operations: an operation called on a TensorDict is dispatched to each of its leaves for you, which makes your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate(), and compatibility with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)

Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers to and from devices to make them both safe and fast. By default, data transfers are asynchronous, and synchronization is performed when needed.

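import torch
from tensordict import TensorDict

dict_of_tensor = {"obs": torch.randn(3, 4), "reward": torch.randn(3, 1)}
# Fast and safe asynchronous copy to 'cuda' (requires a CUDA device)
td_cuda = TensorDict(dict_of_tensor, batch_size=[3], device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)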

Coding an optimizer

For instance, with TensorDict you can write the Adam optimizer as you would for a single torch.Tensor and apply it directly to a TensorDict of parameters. On CUDA, these operations rely on fused kernels, which makes them fast to execute:

import torch
from tensordict import TensorDict


class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock for efficiency: the set of keys can no longer change
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, both initialized to zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        grad = self.weights.grad
        # Update the running moment estimates in place
        self._mu.mul_(self.beta1).add_(grad.mul(1 - self.beta1))
        self._sigma.mul_(self.beta2).add_(grad.pow(2).mul(1 - self.beta2))
        self.t += 1
        # Bias correction is done out of place so the running moments are preserved
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # Parameter update: w <- w - alpha * mu / (sqrt(sigma) + eps)
        self.weights.data.add_(mu.div(sigma.sqrt().add(self.eps)).mul(-self.alpha))

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can reuse the same training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transforms, model prediction, loss computation, etc.) can be tailored to the use case at hand without affecting the others. For instance, the loop above can be used for classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, run:

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To get the latest features, install the nightly build:

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking (BC-breaking) changes may be introduced, but they should come with a warning. Such changes should not happen often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.

Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

tensordict_nightly-2024.8.28-cp312-cp312-win_amd64.whl (331.0 kB)

Uploaded CPython 3.12 Windows x86-64

tensordict_nightly-2024.8.28-cp311-cp311-win_amd64.whl (330.8 kB)

Uploaded CPython 3.11 Windows x86-64

tensordict_nightly-2024.8.28-cp310-cp310-win_amd64.whl (330.1 kB)

Uploaded CPython 3.10 Windows x86-64

tensordict_nightly-2024.8.28-cp39-cp39-win_amd64.whl (329.8 kB)

Uploaded CPython 3.9 Windows x86-64

tensordict_nightly-2024.8.28-cp38-cp38-win_amd64.whl (330.2 kB)

Uploaded CPython 3.8 Windows x86-64

File hashes

Hashes for tensordict_nightly-2024.8.28-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 44e70fa28699b8359a099eeb04b38ab823385803ab2b00e9912bc6a81b2e547e
MD5 174eda8fda07b51c4fcca8daddce0d35
BLAKE2b-256 9797a1f75db3b0ed932285ed6464720717c7bc9d5914d9392dacbc5482fce3ac

Hashes for tensordict_nightly-2024.8.28-cp312-cp312-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 d19a59b11356a2670061d50e24bd280b70635689f618ff680ae93dc73e17463e
MD5 4e68a5d0a633daa7203210d1864d9aa2
BLAKE2b-256 95de7c04538a5cd8ec44c1edc29f406d6d2712a4ea884f48bcfbfe2b55d34c72

Hashes for tensordict_nightly-2024.8.28-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 974200c921c37f10458da6e1b0dd568ff135513c371e062f50540878fd55d41b
MD5 8e6f8a74cfc913a42fcc3bde65b437d8
BLAKE2b-256 a45dffb2a8f0d1ff846cff84ee9b707bec3ca9c9b2c3d2d518df9f34a0abebd0

Hashes for tensordict_nightly-2024.8.28-cp311-cp311-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 0f6d15e409715b6ec36a972afb7166d96016d8eb62c7fe395c127c6845b6a001
MD5 c0831e58be92337c7997349aca29b08e
BLAKE2b-256 5361d5aff30676f35dfa8607bb9256012abb780293db42a9cd11bb3c421bcd01

Hashes for tensordict_nightly-2024.8.28-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 85a4f4f529fe36014d6a0f9a1732aa550adc1578c70283dddcb5a35c1303f3cd
MD5 579db125b46e497a028a32e6012a672a
BLAKE2b-256 443a5ab59e74c622a9f7f900a78ba3946ed33285c1fb436dd232f81bc61baa07

Hashes for tensordict_nightly-2024.8.28-cp310-cp310-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 25912806c9b346c236348766ec38cb1a4e744ea4c8e8945a9b05daddba06f2fa
MD5 fec925790be18ccaed9ff28a3feab129
BLAKE2b-256 63f8cd3bd4bdedb240ee499ec9371fa1ab5e0fbad65d5a4629df1a1354614a31

Hashes for tensordict_nightly-2024.8.28-cp39-cp39-win_amd64.whl
Algorithm Hash digest
SHA256 e7b66e0a973721916ae195c2bac2471d932175d231cf01e1f1451cb70f77846b
MD5 a20743cb34a5c31a4aa090a6c23839ba
BLAKE2b-256 7e83ee31d1e1688569766264e001271ee703d80d1b3912ff0f4af86f0e4d24bf

Hashes for tensordict_nightly-2024.8.28-cp39-cp39-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 bc02192101f523cfe25315c3285c585cf545a7745f0e1839d9ff50445fb5383b
MD5 475bece7cb35087666219224e83b9422
BLAKE2b-256 b07e9eca7219ed77aa165206ce5aaeadd1ca673292a0fceed23fb0a8cc9c2fa6

Hashes for tensordict_nightly-2024.8.28-cp38-cp38-win_amd64.whl
Algorithm Hash digest
SHA256 9a28a0e361f6088e399e37931e372d575797145431138166c4dd817b129c7da8
MD5 391c87e02251d7ecc516e7d1fe698823
BLAKE2b-256 3c6ed59f1f1cd7b7d0137b505e8b1480a5d6655b055d40af8ea58f4d54d32d8e

Hashes for tensordict_nightly-2024.8.28-cp38-cp38-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 9558a82d0d50f82d938afeb4fbd54b6c3e753cdfc18259c9be49a97bfd2af4e7
MD5 01a8a56b3e1d2d9aeb791c123b694267
BLAKE2b-256 68b7575a638a03d66667dfe0c9b6eacc2bd4ac9001705a5cc3743c68366c9ff0
