
📖 TensorDict

TensorDict is a dictionary-like class that inherits properties from tensors, making it easy to work with collections of tensors in PyTorch. It provides a simple and intuitive way to manipulate and process tensors, allowing you to focus on building and training your models.

Key Features | Examples | Installation | Citation | License

Key Features

TensorDict makes your codebases more readable, compact, modular and fast. It abstracts away boilerplate, making your code less error-prone by taking care of dispatching each operation to the leaf tensors for you.

The key features are:

  • 🧮 Composability: TensorDict generalizes torch.Tensor operations to collections of tensors.
  • ⚡️ Speed: asynchronous transfer to device, fast node-to-node communication through consolidate, compatible with torch.compile.
  • ✂️ Shape operations: Perform tensor-like operations on TensorDict instances, such as indexing, slicing or concatenation.
  • 🌐 Distributed / multiprocessed capabilities: Easily distribute TensorDict instances across multiple workers, devices and machines.
  • 💾 Serialization and memory-mapping
  • λ Functional programming and compatibility with torch.vmap
  • 📦 Nesting: Nest TensorDict instances to create hierarchical structures.
  • Lazy preallocation: Preallocate memory for TensorDict instances without initializing the tensors.
  • 📝 Specialized dataclass for torch.Tensor (@tensorclass)

tensordict.png

Examples

This section presents a couple of stand-out applications of the library. Check our Getting Started guide for an overview of TensorDict's features!

Fast copy on device

TensorDict optimizes transfers to and from device so that they are both safe and fast. By default, data transfers are made asynchronously and synchronization points are inserted whenever needed.

import torch
from tensordict import TensorDict

# Example payload: any mapping from keys to tensors
dict_of_tensor = {"a": torch.randn(3), "b": torch.zeros(3)}

# Fast and safe asynchronous copy to 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# Fast and safe asynchronous copy to 'cpu'
td_cpu = td_cuda.to("cpu")
# Force synchronous copy
td_cpu = td_cuda.to("cpu", non_blocking=False)

Coding an optimizer

For instance, using TensorDict you can code the Adam optimizer as you would for a single torch.Tensor and apply it to a TensorDict input as well. On CUDA, these operations rely on fused kernels, making them very fast to execute:

import torch
from tensordict import TensorDict


class Adam:
    def __init__(self, weights: TensorDict, alpha: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-6):
        # Lock the tensordict for efficiency (prevents accidental key changes)
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        # First and second moment estimates, initialized to zero, one leaf per parameter
        self._mu = weights.data.mul(0.0)
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Update the biased moment estimates in place
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias-corrected estimates (out of place, so the running stats stay intact)
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)), alpha=-self.alpha)

Training a model

Using tensordict primitives, most supervised training loops can be rewritten in a generic way:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking changes may still be introduced, but they should come with a deprecation warning. These should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.


Built Distributions

This nightly release (tensordict_nightly 2024.10.8) ships wheels for CPython 3.9-3.12 on Windows x86-64 and manylinux1 x86-64. The Windows wheels:

  • tensordict_nightly-2024.10.8-cp312-cp312-win_amd64.whl (347.0 kB)
  • tensordict_nightly-2024.10.8-cp311-cp311-win_amd64.whl (346.4 kB)
  • tensordict_nightly-2024.10.8-cp310-cp310-win_amd64.whl (345.4 kB)
  • tensordict_nightly-2024.10.8-cp39-cp39-win_amd64.whl (345.3 kB)

No source distribution is available for this release.
