TensorDict-like functionality for PyTorch with PyTree compatibility and torch.compile support
Tensor Container
Tensor containers for PyTorch with PyTree compatibility and torch.compile optimization
Tensor Container provides efficient, type-safe tensor container implementations for PyTorch workflows. It includes PyTree integration and torch.compile optimization for batched tensor operations.
The library provides two tensor containers (dict-style and dataclass-style) and a set of distributions that mirror torch.distributions.
What is TensorContainer?
TensorContainer transforms how you work with structured tensor data in PyTorch by providing tensor-like operations for entire data structures. Instead of manually managing individual tensors, TensorContainer lets you treat complex data as unified entities that behave just like regular tensors.
Core Benefits
- Unified Operations: Apply tensor operations like view(), permute(), detach(), and device transfers to entire data structures
- Drop-in Compatibility: Seamless integration with existing PyTorch workflows and torch.compile
- Zero Boilerplate: Eliminate manual parameter handling and type-specific operations
- Type Safety: Full IDE support with static typing and autocomplete
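import torch
from tensorcontainer import TensorDict
data = TensorDict(
    {"a": torch.rand(24), "b": torch.rand(24)},
    shape=(24,),
    device="cpu",
)
# Single operation transforms entire structure
# (falls back to CPU when no GPU is available)
device = "cuda" if torch.cuda.is_available() else "cpu"
data = data.view(2, 3, 4).permute(1, 0, 2).to(device).detach()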
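For comparison, the following is a minimal sketch of the same transformation done by hand on a plain dict of tensors. It uses only PyTorch; the helper name `transform_all` is hypothetical and not part of TensorContainer, and the device transfer is left out so the sketch runs on CPU-only machines.

```python
import torch


def transform_all(tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Apply view -> permute -> detach to every tensor in a plain dict.

    Hypothetical helper for illustration; not part of TensorContainer.
    """
    return {
        name: t.view(2, 3, 4).permute(1, 0, 2).detach()
        for name, t in tensors.items()
    }


raw = {"a": torch.rand(24, requires_grad=True), "b": torch.rand(24)}
out = transform_all(raw)
```

Every additional operation (indexing, stacking, device moves) needs the same per-key loop, which is the boilerplate a container type is meant to absorb.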
Key Features
- ⚡ JIT Compilation: Designed for torch.compile with fullgraph=True, minimizing graph breaks and maximizing performance
- 📐 Batch/Event Semantics: Clear distinction between batch dimensions (consistent across tensors) and event dimensions (tensor-specific)
- 🔄 Device Management: Move entire structures between CPU/GPU with single operations and flexible device compatibility
- 🔒 Type Safety: Full IDE support with static typing and autocomplete
- 🏗️ Multiple Container Types: Three specialized containers for different use cases:
  - TensorDict for dynamic, dictionary-style data collections
  - TensorDataClass for type-safe, dataclass-based structures
  - TensorDistribution for probabilistic modeling with 40+ probability distributions
- 🔧 Advanced Operations: Full PyTorch tensor operations support including view(), permute(), stack(), cat(), and more
- 🎯 Advanced Indexing: Complete PyTorch indexing semantics with boolean masks, tensor indices, and ellipsis support
- 📊 Shape Validation: Automatic verification of tensor compatibility with detailed error messages
- 🌳 Nested Structure Support: Nest different TensorContainer types inside one another
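The batch/event distinction and the shape validation listed above can be illustrated with a small pure-Python sketch. The function `split_batch_event` is a hypothetical name introduced here for illustration; it is an assumption about how such a check could work, not the library's actual implementation.

```python
def split_batch_event(tensor_shape, batch_shape):
    """Return the event shape left after the leading batch dimensions.

    Hypothetical helper: raises ValueError when the tensor's leading
    dimensions do not match the container's batch shape.
    """
    tensor_shape = tuple(tensor_shape)
    batch_shape = tuple(batch_shape)
    n = len(batch_shape)
    if tensor_shape[:n] != batch_shape:
        raise ValueError(
            f"leading dims {tensor_shape[:n]} do not match batch shape {batch_shape}"
        )
    return tensor_shape[n:]


# With a batch shape of (32,), each tensor keeps its own event dimensions.
event_obs = split_batch_event((32, 128), (32,))   # (128,)
event_reward = split_batch_event((32, 1), (32,))  # (1,)
```

Under this reading, an operation such as view() acts on the shared batch dimensions and leaves each tensor's event dimensions untouched.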
Table of Contents
- What is TensorContainer?
- Installation
- Quick Start
- Features
- API Overview
- torch.compile Compatibility
- Examples
- Contributing
- Documentation
- License
- Authors
- Contact and Support
Installation
Using pip
pip install tensorcontainer
Requirements
- Python 3.9+
- PyTorch 2.6+
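A quick way to check an environment against these minimums is sketched below. The helper `meets_minimum` is a hypothetical name, and the version strings are example inputs; at runtime you would pass `torch.__version__` instead.

```python
import re
import sys


def meets_minimum(version: str, minimum: tuple) -> bool:
    """Compare the leading numeric components of a version string to a minimum."""
    numbers = tuple(int(part) for part in re.findall(r"\d+", version)[: len(minimum)])
    return numbers >= minimum


python_ok = sys.version_info >= (3, 9)

# Example strings only; replace with torch.__version__ in a real check.
torch_ok = meets_minimum("2.6.0+cu124", (2, 6))
torch_too_old = meets_minimum("2.5.1", (2, 6))
```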
Quick Start
TensorContainer transforms how you work with structured tensor data. Instead of managing individual tensors, you can treat entire data structures as unified entities that behave like regular tensors.
1. TensorDict: Dynamic Data Collections
Perfect for reinforcement learning data and dynamic collections:
import torch
from tensorcontainer import TensorDict
# Create a container for RL training data
data = TensorDict({
    'observations': torch.randn(32, 128),
    'actions': torch.randn(32, 4),
    'rewards': torch.randn(32, 1)
}, shape=(32,))
# Dictionary-like access with tensor operations
obs = data['observations']
data['advantages'] = torch.randn(32, 1) # Add new fields dynamically
# Batch operations work seamlessly
batch = torch.stack([data, data]) # Shape: (2, 32)
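To make the resulting shapes concrete, here is a minimal plain-PyTorch sketch of what stacking two such collections amounts to per key. The helper `stack_dicts` is hypothetical and operates on ordinary dicts, not on TensorDict.

```python
import torch


def stack_dicts(items):
    """Stack a list of dicts with identical keys along a new leading dimension."""
    keys = items[0].keys()
    return {key: torch.stack([item[key] for item in items]) for key in keys}


step = {
    "observations": torch.randn(32, 128),
    "rewards": torch.randn(32, 1),
}
stacked = stack_dicts([step, step])
# The shared leading (batch) dimensions become (2, 32); each tensor keeps
# its own trailing event dimensions:
# observations -> (2, 32, 128), rewards -> (2, 32, 1)
```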
2. TensorDataClass: Type-Safe Structures
Ideal for model inputs and structured data that benefit from static type checking:
import torch
from tensorcontainer import TensorDataClass
class ModelInput(TensorDataClass):
    features: torch.Tensor
    labels: torch.Tensor
# Create with full type safety and IDE support
batch = ModelInput(
    features=torch.randn(32, 64, 784),
    labels=torch.randint(0, 10, (32, 64)),
    shape=(32, 64)
)
# Unified operations on entire structure - reshape all tensors at once
batch = batch.view(2048)
# Type-safe access with autocomplete works on reshaped data too
loss = torch.nn.functional.cross_entropy(batch.features, batch.labels)
3. TensorDistribution: Probabilistic Modeling
Streamline probabilistic computations in reinforcement learning and generative models:
import torch
from tensorcontainer.tensor_distribution import TensorNormal
normal = TensorNormal(
    loc=torch.zeros(100, 4),
    scale=torch.ones(100, 4)
)
# With torch.distributions we need to extract the parameters, detach them
# and create a new Normal distribution. With TensorDistribution we just call
# .detach() on the distribution. We can also apply other tensor operations,
# such as .view()!
detached_normal = normal.detach()
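The manual route that the comment above describes looks roughly like this with plain torch.distributions. This sketch uses only PyTorch; the `requires_grad=True` parameters are an assumption added so that detaching has a visible effect.

```python
import torch
from torch.distributions import Normal

# Parameters that track gradients (assumed here for illustration).
loc = torch.zeros(100, 4, requires_grad=True)
scale = torch.ones(100, 4, requires_grad=True)
normal = Normal(loc=loc, scale=scale)

# Manual route: pull out the parameters, detach each one, rebuild the distribution.
detached = Normal(loc=normal.loc.detach(), scale=normal.scale.detach())

# Samples from the detached distribution no longer carry gradient history.
sample = detached.rsample()
```

Each distribution family has its own parameter names, so this pattern has to be rewritten per family, which is the repetition a single .detach() call avoids.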
Documentation
The project includes comprehensive documentation:
- docs/user_guide/overview.md: Complete user guide with examples and best practices
- docs/developer_guide/compatibility.md: Python version compatibility guide and best practices
- docs/developer_guide/testing.md: Testing philosophy, standards, and guidelines
- Source Code Documentation: Extensive docstrings and type annotations throughout the codebase
- Test Coverage: 643+ tests covering all major functionality with 86% code coverage
License
This project is licensed under the MIT License - see the LICENSE file for details.
Authors
- Tim Joseph - mctigger
Contact and Support
- Issues: Report bugs and request features on GitHub Issues
- Discussions: Join conversations on GitHub Discussions
- Email: For direct inquiries, contact tim@mctigger.com
Tensor Container is an academic research project for learning PyTorch internals and tensor container patterns. For production applications, we strongly recommend using the official pytorch/tensordict library, which is actively maintained by the PyTorch team.