
📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away tailored operations and makes your code less error-prone, as it takes care of dispatching each operation to the leaves (the individual tensors) for you.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate(), and compatibility with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)


Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers will be made asynchronously and synchronizations will be called whenever needed.

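import torch
from tensordict import TensorDict

# Hypothetical source data: a plain dict of tensors
dict_of_tensor = {"obs": torch.randn(3, 4), "reward": torch.zeros(3, 1)}

# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)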

Coding an optimizer

For instance, with TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input as well. On CUDA, these operations rely on fused kernels, which makes them very fast to execute:

import torch
from tensordict import TensorDict


class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates both start at zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the running (biased) moment estimates in place
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-corrected copies: out-of-place, so the running moments are not overwritten
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

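# `dataset`, `model`, `loss_module` and `optimizer` are defined beforehand;
# `loss_module` is expected to return a scalar loss tensor
for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()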

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transforms, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be used across classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking (BC-breaking) changes may be introduced, but they should come with a warning. Hopefully, such changes will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.

