Project description

📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away per-tensor bookkeeping: operations are dispatched to every leaf tensor for you, which makes your code less error-prone.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors (see the short sketch after this list).
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)
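
Below is a brief, hypothetical sketch (not taken from the README) illustrating a few of these features: construction with a batch size, tensor-like shape operations, nesting, and a @tensorclass container. The keys, shapes and the Data class are made up for illustration.

import torch
from tensordict import TensorDict, tensorclass

# Construction: keys are arbitrary, leading dims of every leaf must match batch_size
td = TensorDict({
    "obs": torch.zeros(4, 3, 84, 84),
    "next": {"obs": torch.zeros(4, 3, 84, 84)},   # nested tensordict
}, batch_size=[4])

# Shape operations apply to every leaf at once
row = td[0]                    # index along the batch dimension
chunks = td.split(2, dim=0)    # split into two tensordicts of batch_size [2]
stacked = torch.stack(chunks)  # stack them back, batch_size [2, 2]

# @tensorclass: a dataclass-like container with the same tensor semantics
@tensorclass
class Data:
    obs: torch.Tensor
    reward: torch.Tensor

data = Data(obs=torch.zeros(4, 3, 84, 84), reward=torch.zeros(4, 1), batch_size=[4])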

Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers are asynchronous and synchronizations are issued whenever needed.

from tensordict import TensorDict

# 'dict_of_tensor' is any plain dict mapping names to torch.Tensor values
# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force a synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)

Coding an optimizer

For instance, using TensorDict you can write the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input. On CUDA, these operations rely on fused kernels, making them very fast to execute:

from tensordict import TensorDict


class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the tensordict for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, initialized to zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the biased moment estimates in place
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias correction, out of place so the running estimates are preserved
        mu = self._mu / (1 - self.beta1 ** self.t)
        sigma = self._sigma / (1 - self.beta2 ** self.t)
        # Parameter update, applied to every leaf of the tensordict at once
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))
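
As a minimal, hypothetical usage sketch (not part of the original example): a module's parameters can be gathered into a TensorDict with TensorDict.from_module and then driven by the Adam class above. The module, input shape and dummy loss below are assumptions made for illustration.

import torch
from torch import nn
from tensordict import TensorDict

# Hypothetical module whose parameters are collected into a (nested) TensorDict
module = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 1))
params = TensorDict.from_module(module)

optim = Adam(params)

x = torch.randn(8, 3)           # dummy input
loss = module(x).pow(2).mean()  # dummy loss
loss.backward()
optim.step()                    # updates every parameter leaf at once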

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transforms, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.
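
As an illustrative sketch (not from the README), the generic model and loss_module above could be built with tensordict.nn.TensorDictModule, which wraps a regular nn.Module so that it reads its inputs from, and writes its outputs back into, the tensordict. The keys, shapes and loss function below are assumptions.

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# Hypothetical model: reads "image" from the tensordict and writes "logits" back into it
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model = TensorDictModule(net, in_keys=["image"], out_keys=["logits"])

def loss_module(td):
    # reads predictions and targets from the tensordict, returns a scalar loss
    return nn.functional.cross_entropy(td["logits"], td["label"])

optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)

# a single fake batch standing in for one element of `dataset`
data = TensorDict(
    {"image": torch.randn(32, 1, 28, 28), "label": torch.randint(10, (32,))},
    batch_size=[32],
)
data = model(data)
loss = loss_module(data)
loss.backward()
optimizer.step()
optimizer.zero_grad()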

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, and with PyTorch 1.12 and later.

To get the latest features, install the nightly build:

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.


Download files

Download the file for your platform.

Source Distributions

No source distribution files are available for this release.

Built Distributions

  • tensordict_nightly-2024.9.19-cp312-cp312-win_amd64.whl (348.2 kB): CPython 3.12, Windows x86-64
  • tensordict_nightly-2024.9.19-cp311-cp311-win_amd64.whl (347.6 kB): CPython 3.11, Windows x86-64
  • tensordict_nightly-2024.9.19-cp310-cp310-win_amd64.whl (346.7 kB): CPython 3.10, Windows x86-64
  • tensordict_nightly-2024.9.19-cp39-cp39-win_amd64.whl (346.6 kB): CPython 3.9, Windows x86-64

File details

Hashes for each built distribution:

tensordict_nightly-2024.9.19-cp312-cp312-win_amd64.whl
  SHA256: 2ee2f2794c2d03cac7d84efd33320b7d45dfd2126b4c1de9d7c1cba44ba91d39
  MD5: 591f90a1cb8bb59f03cdc026843ba28b
  BLAKE2b-256: 6a675c76da1918e39436adc3de1accf70198004224bc1c4f152ccc44c3e376fd

tensordict_nightly-2024.9.19-cp312-cp312-manylinux1_x86_64.whl
  SHA256: 811e0af3b761e5a7a7a0db43b06f2ec18b27ab328996fb7deac182df9d4186a5
  MD5: c7faadea7b68704fbfb22c620b080d02
  BLAKE2b-256: 9d77db227e15d8d5718ca763e36bb45fc041a0fc752159de488f3e3587d81c43

tensordict_nightly-2024.9.19-cp311-cp311-win_amd64.whl
  SHA256: 4adec125bd10b3ed4e7bd88a655f6202304563279c30232016d7063e518b1af6
  MD5: 0a3e9c9bfd70185715e546d755c774df
  BLAKE2b-256: 2f1442beb7f7d658a5ffcf9cdd622650c9852c4c83cf991dc958d35b5c50b704

tensordict_nightly-2024.9.19-cp311-cp311-manylinux1_x86_64.whl
  SHA256: 677e48f628ed25c0b869d69613d0adf7149504604e615a107d1b4a1187800334
  MD5: 922e5d23ce3b5e32c60c08f12bd734af
  BLAKE2b-256: 6436587443031ff502cbbfb4ab0dd4b412e98d25366b99b0c47945adc4923c95

tensordict_nightly-2024.9.19-cp310-cp310-win_amd64.whl
  SHA256: f3c73860a577683534fd8dbde22484268fca9124c16de9ddcbca43228cfced9f
  MD5: 6cdf227972bb0e40044fff3e6652939e
  BLAKE2b-256: 5fff4e5b3eca2c143116f7052ba6928fbd0e588526f9227c2f847b003c07db3b

tensordict_nightly-2024.9.19-cp310-cp310-manylinux1_x86_64.whl
  SHA256: f5144536601321974c10bbc078067dd7c5939c9f0c0e7c0e2866db135d595faa
  MD5: b35b0a95e9fe4b7139e38819736731c2
  BLAKE2b-256: 28140ecfb430d25b75997087848628ef3f9888d12cc39f100e469d6d79185a4f

tensordict_nightly-2024.9.19-cp39-cp39-win_amd64.whl
  SHA256: 2e385c6d1bc8fe6592c756fd097a69729929207594c447de7095539b91dffba3
  MD5: e175845f47050bda39f874ad683711e0
  BLAKE2b-256: d03ec60d1bd4fd4c04dd965947246fe2ddbdd8270eca2fea51ebbd993b48fff4

tensordict_nightly-2024.9.19-cp39-cp39-manylinux1_x86_64.whl
  SHA256: 03ac6bbb00c925e8c1524f39b522c6a002997e7bf1e1493fc5a90dea5c340af0
  MD5: 2c925b4b95299b8f21d7bb59f4cfa1db
  BLAKE2b-256: f2323324e9055a3498bdb47d184b097e938229fcf19f9dd5aae87d3a25569d53
