📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored per-tensor operations and dispatches them over the leaf tensors for you, making your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)

tensordict.png

Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes device transfers to make them both safe and fast. By default, transfers are asynchronous and synchronizations are triggered whenever needed.

# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(dict_of_tensors, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force a synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)

Coding an optimizer

For instance, using TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to TensorDict inputs as well. On CUDA, these operations rely on fused kernels, making them very fast to execute:

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the tensordict for efficiency
        self.weights = weights.lock_()
        self.t = 0

        self._mu = weights.data.clone()
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-corrected estimates, out of place so the running stats survive
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each step of the loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can be reused across classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, and PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may still be introduced, though they should come with fair warning. These should hopefully be rare, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
