
📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.


Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored operations: you write an operation once, and TensorDict dispatches it to every leaf tensor for you, which makes your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate(), and compatibility with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)


Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers to and from a device to make them both safe and fast. By default, data transfers are made asynchronously, and synchronizations are called whenever needed.

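import torch
from tensordict import TensorDict

# Example data: any dict mapping keys to tensors that share a leading batch dimension
dict_of_tensor = {"obs": torch.randn(4, 3), "reward": torch.randn(4, 1)}
# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(dict_of_tensor, batch_size=[4], device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)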

Coding an optimizer

Using TensorDict, you can code the Adam optimizer as you would for a single torch.Tensor and apply it directly to a TensorDict input. On CUDA, these operations rely on fused kernels, which makes them fast to execute:

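from tensordict import TensorDict


class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock for efficiency: no keys can be added or removed afterwards
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, both initialized to zero
        self._mu = weights.data.clone().zero_()
        self._sigma = weights.data.clone().zero_()
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        grad = self.weights.grad
        # Update the running (biased) moment estimates in place
        self._mu.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias correction is done out of place so the running estimates are preserved
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # weights <- weights - alpha * mu / (sqrt(sigma) + eps)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))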

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can reuse the same training loop across highly heterogeneous tasks. Each step of the loop (data collection and transforms, model prediction, loss computation, etc.) can be tailored to the use case at hand without affecting the others. For instance, the loop above can be used for classification, segmentation and many other tasks.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To get the latest features, install the nightly build instead:

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking (BC-breaking) changes may be introduced, but they should come with a warning. Such changes should not happen often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
