📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebase more readable, compact, modular and fast. It abstracts away tailored operations and makes your code less error-prone, as it takes care of dispatching each operation over the leaf tensors for you.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)


Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers will be made asynchronously and synchronizations will be called whenever needed.

# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)

Coding an optimizer

For instance, with TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input. On CUDA, these operations rely on fused kernels, making them very fast to execute:

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # Moment estimates start at zero, mirroring torch.optim.Adam
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the biased moment estimates in place
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-correct out of place so the running estimates stay intact
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)), alpha=-self.alpha)

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, as well as PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may still be introduced, but they should come with a deprecation warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
