📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away per-tensor boilerplate: when you call an operation on a TensorDict, it is dispatched to every leaf tensor for you, which makes your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfers to device, fast node-to-node communication through consolidate, and compatibility with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)

Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers are made asynchronously, and synchronization is triggered whenever needed.

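import torch
from tensordict import TensorDict

# Example data (illustrative): a flat dict of tensors
dict_of_tensor = {"obs": torch.randn(10, 3), "reward": torch.randn(10, 1)}

# Fast and safe asynchronous copy to 'cuda' (requires a CUDA device)
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)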

Coding an optimizer

With TensorDict, you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input. On CUDA, these operations rely on fused kernels, which makes them very fast to execute:

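from tensordict import TensorDict

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the structure (no keys added or removed) for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, both initialized to zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        grad = self.weights.grad
        # Update the running (biased) moment estimates in place
        self._mu.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias correction is computed out of place so the running moments are preserved
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # weights <- weights - alpha * mu / (sqrt(sigma) + eps), applied to every leaf
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))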
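For reference, the per-leaf update that step is meant to compute is the standard bias-corrected Adam rule. In the notation below, m stands for _mu, v for _sigma, θ for the leaves of weights, g_t for weights.grad, α for alpha and ε for eps; squares, square roots and divisions are elementwise. In this formulation both moments start at zero, and the bias correction produces new quantities rather than overwriting m_t and v_t, which must carry over unchanged to the next step.

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad m_0 = 0,\\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^{2}, \qquad v_0 = 0,\\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^{t}}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^{t}},\\
\theta_t &= \theta_{t-1} - \alpha\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}.
\end{aligned}
```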

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, the same training loop can be reused across highly heterogeneous tasks. Each step of the loop (data collection and transforms, model prediction, loss computation, etc.) can be tailored to the use case at hand without affecting the others. For instance, the loop above can be used for classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To get the latest features, install the nightly build:

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is in beta, which means that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Such changes should be infrequent, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.

Download files

Source distributions: none available for this release.

Built distributions:

  • tensordict_nightly-2024.8.30-cp312-cp312-win_amd64.whl (331.0 kB): CPython 3.12, Windows x86-64
  • tensordict_nightly-2024.8.30-cp311-cp311-win_amd64.whl (330.8 kB): CPython 3.11, Windows x86-64
  • tensordict_nightly-2024.8.30-cp310-cp310-win_amd64.whl (330.1 kB): CPython 3.10, Windows x86-64
  • tensordict_nightly-2024.8.30-cp39-cp39-win_amd64.whl (329.8 kB): CPython 3.9, Windows x86-64
  • tensordict_nightly-2024.8.30-cp38-cp38-win_amd64.whl (330.2 kB): CPython 3.8, Windows x86-64

