
📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored per-tensor operations, making your code less error-prone: it takes care of dispatching each operation over the leaves for you.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)


Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers will be made asynchronously and synchronizations will be called whenever needed.

from tensordict import TensorDict

# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)

Coding an optimizer

For instance, using TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input as well. On CUDA, these operations rely on fused kernels, making them very fast to execute:

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the tensordict for efficiency: no keys can be added or removed
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # Running first- and second-moment estimates, initialized to zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the biased moment estimates in place
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-correct out of place so the running estimates are not mutated
        mu = self._mu / (1 - self.beta1 ** self.t)
        sigma = self._sigma / (1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can be easily used across classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and later, and PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel.

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.


Download files

Download the file for your platform.

Source Distributions

No source distribution files are available for this release.

Built Distributions

  • tensordict_nightly-2024.9.22-cp312-cp312-win_amd64.whl (348.2 kB): CPython 3.12, Windows x86-64
  • tensordict_nightly-2024.9.22-cp311-cp311-win_amd64.whl (347.6 kB): CPython 3.11, Windows x86-64
  • tensordict_nightly-2024.9.22-cp310-cp310-win_amd64.whl (346.6 kB): CPython 3.10, Windows x86-64
  • tensordict_nightly-2024.9.22-cp39-cp39-win_amd64.whl (346.6 kB): CPython 3.9, Windows x86-64
