
TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your code base more readable, compact, modular and fast. It abstracts away tailored per-tensor operations, making your code less error-prone: it takes care of dispatching each operation on the leaves for you.

The key features are:

  • Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • Speed: asynchronous transfers to device, fast node-to-node communication through consolidate, compatibility with torch.compile.
  • Shape operations: perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • Distributed / multiprocessed capabilities: easily distribute TensorDict instances across multiple workers, devices and machines.
  • Serialization and memory-mapping.
  • Functional programming and compatibility with torch.vmap.
  • Nesting: nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: preallocate memory for TensorDict instances without initializing the tensors.
  • A specialized dataclass for torch.Tensor (@tensorclass).


Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers from/to device to make them safe and fast. By default, data transfers will be made asynchronously and synchronizations will be called whenever needed.

from tensordict import TensorDict
import torch

# A small payload to copy around (illustrative)
dict_of_tensor = {"a": torch.randn(3), "b": torch.zeros(3)}

# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)

Coding an optimizer

For instance, using TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input as well. On CUDA, these operations rely on fused kernels, making execution very fast:

class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the tensordict for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, initialized to zero
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the biased moment estimates in place
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias correction must be out-of-place to keep the running estimates intact
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        # Parameter update
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This works with Python 3.7 and later, and PyTorch 1.12 and later.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may be introduced, but they should come with a deprecation warning. Hopefully these will not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.


Download files


Source Distributions

No source distribution files are available for this release.

Built Distributions

tensordict-0.6.2-cp312-cp312-manylinux1_x86_64.whl (360.6 kB): CPython 3.12, manylinux1 x86-64
tensordict-0.6.2-cp312-cp312-macosx_11_0_arm64.whl (668.9 kB): CPython 3.12, macOS 11.0+ ARM64
tensordict-0.6.2-cp311-cp311-manylinux1_x86_64.whl (360.4 kB): CPython 3.11, manylinux1 x86-64
tensordict-0.6.2-cp311-cp311-macosx_11_0_arm64.whl (669.7 kB): CPython 3.11, macOS 11.0+ ARM64
tensordict-0.6.2-cp310-cp310-manylinux1_x86_64.whl (359.9 kB): CPython 3.10, manylinux1 x86-64
tensordict-0.6.2-cp310-cp310-macosx_11_0_arm64.whl (668.4 kB): CPython 3.10, macOS 11.0+ ARM64
tensordict-0.6.2-cp39-cp39-manylinux1_x86_64.whl (359.9 kB): CPython 3.9, manylinux1 x86-64
tensordict-0.6.2-cp39-cp39-macosx_11_0_arm64.whl (668.5 kB): CPython 3.9, macOS 11.0+ ARM64

File hashes

tensordict-0.6.2-cp312-cp312-manylinux1_x86_64.whl
    SHA256: 6ce86d8f9ed53d1c893e0064729a799eb06371902e2d5a2f3c83c6af010cf3c2
    MD5: c3e4a9d7cca76a5bc4cfd7619f0638b3
    BLAKE2b-256: 64aa2b562740c8be079aba4bb9eb37cb94ed042d65e48f739577eb70eee617c2

tensordict-0.6.2-cp312-cp312-macosx_11_0_arm64.whl
    SHA256: 529515ffcb0c6ab5d1c0c85798defe8972e5390412c356c79b87f4073e48b537
    MD5: 9eb7cba4934ad21bae1ec7d7a133f239
    BLAKE2b-256: d217a2bdb51d5d1b2f06debdf8f60c159cf328c778b6cc8103940f2079e2e767

tensordict-0.6.2-cp311-cp311-manylinux1_x86_64.whl
    SHA256: 3e036ae56532b6308a9e95de29db97f1ec9b83f569c716a602d7d48dc0b47ef5
    MD5: ddf7c2f6477656c0efac22bec0380e7e
    BLAKE2b-256: 2904a58ed454ad7f8f8a65f6ebe2df279c4d6921234cad3603318cf8a150b0f3

tensordict-0.6.2-cp311-cp311-macosx_11_0_arm64.whl
    SHA256: e5132b369ac9e84eb0e9e28d1954f6d1370e06927ab3c502e02a6d35a1ed629a
    MD5: 108987cae370b49b804c53c6c8d784f7
    BLAKE2b-256: fdc4ddbd7109fef983a726cc135bdd748a7b3747cd3e77dc300ab4cc577a8d64

tensordict-0.6.2-cp310-cp310-manylinux1_x86_64.whl
    SHA256: 24a505926cca17401411b9ed7377013c57dd49d579f350ac2e002d701a7f6cfd
    MD5: b7285d44ec759fe97abeb98a458ef1be
    BLAKE2b-256: 96d0ecdb847d7190cca1e6dcdfbfd374fad82937b890bd4609385cdf3b4d5592

tensordict-0.6.2-cp310-cp310-macosx_11_0_arm64.whl
    SHA256: a00aea0dd7d38026b7e4721d3e4c9759fed99829d8f3d337f4f5174c42138c87
    MD5: 54bc95c6d7cebcc76425c71f9f491bfa
    BLAKE2b-256: a9b140afa9e2e11414b8def6fedb044eb77ebdfb17ab02b1d925339e2dce3ada

tensordict-0.6.2-cp39-cp39-manylinux1_x86_64.whl
    SHA256: 490371099494cdb3250b77d64e5ea63da8c3715a612ffab03e5f9a791161bffd
    MD5: 403b25bdcf6de9d7ab1d4a45a2ffd9f6
    BLAKE2b-256: 2ca8031063d078e09fff8f3bed5d07707e50912fb9c4cff4f65c032eb165b432

tensordict-0.6.2-cp39-cp39-macosx_11_0_arm64.whl
    SHA256: 1bfb1d41877fc19b709a7ad20c3b36f19224e5013263989de8e994bd98559fda
    MD5: 5933e8950c9626b09c9c91388473e5ba
    BLAKE2b-256: 3bca58c7010737dc0dfea44824e1691ac679c69e93d824c907ab1145c2c0f5e9
