
📖 TensorDict

TensorDict is a dictionary-like class that shares many properties of tensors (shape, device, indexing), making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features

TensorDict makes your code-bases more readable, compact, modular and fast. It abstracts away tailored operations and makes your code less error-prone by dispatching each operation to the leaves (the individual tensors) for you.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfers to device, fast node-to-node communication through consolidate(), and compatibility with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)

Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers to and from devices to make them safe and fast. By default, transfers are asynchronous, and synchronization is performed automatically whenever it is needed.

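import torch
from tensordict import TensorDict

# Example data (illustrative): any mapping of names to tensors works
dict_of_tensor = {"obs": torch.randn(3, 4), "reward": torch.zeros(3, 1)}
# Fast and safe asynchronous copy to 'cuda' (requires a CUDA device)
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)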

Coding an optimizer

For instance, with TensorDict you can write the Adam optimizer as you would for a single torch.Tensor and apply it directly to a TensorDict of parameters. On CUDA, these operations rely on fused kernels, which makes them fast to execute:

import torch
from tensordict import TensorDict


class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates both start at zero
        self._mu = weights.data.clone().zero_()
        self._sigma = weights.data.clone().zero_()
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        grad = self.weights.grad
        # Update the running (biased) moment estimates in place
        self._mu.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias correction out of place, so the running estimates are not overwritten
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # w <- w - alpha * mu / (sqrt(sigma) + eps)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))
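To sanity-check the arithmetic of the Adam update, here is the same rule written for a single plain torch.Tensor and compared against torch.optim.Adam over a few steps. The helper adam_step and the toy loss sum(w**2) are hypothetical, chosen only for this comparison; float64 keeps rounding differences negligible.

```python
import torch


def adam_step(param, grad, mu, sigma, t, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-6):
    """Hypothetical helper: one Adam update on a single tensor (mu, sigma updated in place)."""
    mu.mul_(beta1).add_(grad, alpha=1 - beta1)
    sigma.mul_(beta2).add_(grad.pow(2), alpha=1 - beta2)
    # Bias-corrected estimates, computed out of place
    mu_hat = mu / (1 - beta1 ** t)
    sigma_hat = sigma / (1 - beta2 ** t)
    param.sub_(alpha * mu_hat / (sigma_hat.sqrt() + eps))


torch.manual_seed(0)
w_ref = torch.randn(5, dtype=torch.float64, requires_grad=True)
w_ours = w_ref.detach().clone()
w_init = w_ours.clone()
opt = torch.optim.Adam([w_ref], lr=1e-3, betas=(0.9, 0.999), eps=1e-6)
mu = torch.zeros_like(w_ours)
sigma = torch.zeros_like(w_ours)

for t in range(1, 6):
    # Both copies see the same gradient: d/dw sum(w**2) = 2w
    opt.zero_grad()
    (w_ref ** 2).sum().backward()
    grad = 2 * w_ours
    opt.step()
    adam_step(w_ours, grad, mu, sigma, t)
```

Because the bias correction is applied to copies, the running estimates mu and sigma stay intact between steps, which is what the comparison relies on.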

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can reuse a training loop across highly heterogeneous tasks. Each step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without affecting the others. For instance, the loop above can be used for classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To get the latest features, install the nightly build:

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is in beta, meaning that backward-compatibility-breaking (BC-breaking) changes may be introduced, but they should come with advance warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.

Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

  • tensordict_nightly-2024.8.29-cp312-cp312-win_amd64.whl (331.1 kB): CPython 3.12, Windows x86-64
  • tensordict_nightly-2024.8.29-cp311-cp311-win_amd64.whl (330.8 kB): CPython 3.11, Windows x86-64
  • tensordict_nightly-2024.8.29-cp310-cp310-win_amd64.whl (330.1 kB): CPython 3.10, Windows x86-64
  • tensordict_nightly-2024.8.29-cp39-cp39-win_amd64.whl (329.8 kB): CPython 3.9, Windows x86-64
  • tensordict_nightly-2024.8.29-cp38-cp38-win_amd64.whl (330.2 kB): CPython 3.8, Windows x86-64

File hashes

tensordict_nightly-2024.8.29-cp312-cp312-win_amd64.whl
SHA256 04b0b0bfde0151ea76e9463235191ec7502c172ffa5a68bec69545874ea61045
MD5 3ed42aaba654352544c7d2f619ff06e7
BLAKE2b-256 69e4700b2b0d791ee881d5726443c5e1307f23a82b6b95cea9446fee7d2deb49

tensordict_nightly-2024.8.29-cp312-cp312-manylinux1_x86_64.whl
SHA256 cff68eb9d3bd7f287219fdc6911fd7c96f6521686f90ad1a301822889015bbc8
MD5 3eb503d1452155cce50e412fd023c3bf
BLAKE2b-256 daeac8cc91299ce6816082868b27acfff0482e23b4692981d427f84be4acdacb

tensordict_nightly-2024.8.29-cp311-cp311-win_amd64.whl
SHA256 6b90bb1db1c27ceb8a771cbb38635a1f59e1cc0f3b08f30dfa10f1a46aac3066
MD5 f484b0637de47bba967a7b0960e6fa95
BLAKE2b-256 81f6347717361338e4bccc769b5310aaf1f81ea2885e618623aa48d1717bfde5

tensordict_nightly-2024.8.29-cp311-cp311-manylinux1_x86_64.whl
SHA256 103d5709b9b9fc89927b41a968c87a342809f5d8fb8f284d9802c2a0a73ce847
MD5 5f6d70bb542ba1ae6c7f28864fa9d155
BLAKE2b-256 2ac39d875e11f51cb5b4cea5af75735223c4a11d8826283e2bb09a82557f70f2

tensordict_nightly-2024.8.29-cp310-cp310-win_amd64.whl
SHA256 283903b2dc2ff42646b5c16497f31bee911fd052f1556b14c43ae00d8c3b8d19
MD5 3ef0c2a5672519f0aac6150e6597b400
BLAKE2b-256 fd6f09d4bd279e4d37e2b7a03fdfaeb70a3dfd056c2617f716ac6feaaee0da61

tensordict_nightly-2024.8.29-cp310-cp310-manylinux1_x86_64.whl
SHA256 4c4dd557624b88b64387621f8667b31226e89eba006e1ac3136a52f49abe6553
MD5 0ddd70c7fba5eb21b26a691b9299fbb7
BLAKE2b-256 85a0f574ac5ced9a423bbe3a9fea1392620de71c5a7be2e84f2a00f09278026f

tensordict_nightly-2024.8.29-cp39-cp39-win_amd64.whl
SHA256 0e9fc6730d9f004bb68f4027ee7e035252f34a09c5357ea0c0e31cce2313f716
MD5 c71bf97774c7bad5111a69293bdf74d5
BLAKE2b-256 2640954e1a4cdca31447da48fb1b14e6e9988837b4834b3b423a00fd1360b1b4

tensordict_nightly-2024.8.29-cp39-cp39-manylinux1_x86_64.whl
SHA256 cd91cf42c7d0df4834cd1a20f02c2ffb74c8735fe3bbe7d95a409a56c60d5fe6
MD5 eccf98eb4a3e9c61035d1c32d6ddc2f7
BLAKE2b-256 d68600b0556a2e238b9b0d7ea44c8b1a6d43a28fa171fe5652757bb22d86acbc

tensordict_nightly-2024.8.29-cp38-cp38-win_amd64.whl
SHA256 b01e609bdbe9e3aeda37a05f29a2142b4b5c53fb6a0c21c1cb1283a1d07e0fc1
MD5 c1cfbcd50388f9eec2649bc851f95491
BLAKE2b-256 a5a99c8f4a4d224fb842445fe1c4ef44a9997f5785982205bbf63324020e4ec6

tensordict_nightly-2024.8.29-cp38-cp38-manylinux1_x86_64.whl
SHA256 25c72a094c67551b8c0a7cfac9803d1c8e031d867986fb409623bc54bc4b07d2
MD5 7a9dae3958b8e4abab8ca89b58617b89
BLAKE2b-256 a7909d4c8b8f6ce8b265ed077c272956dfc4f005b06483c232a7bcacc8248200
