TensorDict-like functionality for PyTorch with PyTree compatibility and torch.compile support
Tensor Container
Tensor containers for PyTorch with PyTree compatibility and torch.compile optimization
⚠️ Academic Research Project: This project exists solely for academic purposes to explore and learn PyTorch internals. For production use, please use the official, well-maintained torch/tensordict library.
Tensor Container provides efficient, type-safe tensor container implementations for PyTorch workflows. It includes PyTree integration and torch.compile optimization for batched tensor operations.
The library includes tensor containers, probabilistic distributions, and batch/event dimension semantics for machine learning workflows.
What is TensorContainer?
TensorContainer transforms how you work with structured tensor data in PyTorch by providing tensor-like operations for entire data structures. Instead of manually managing individual tensors across devices, batch dimensions, and nested hierarchies, TensorContainer lets you treat complex data as unified entities that behave just like regular tensors.
🚀 Unified Operations Across Data Types
Apply tensor operations like view(), permute(), detach(), and device transfers to entire data structures—no matter how complex:
# Single operation transforms entire distribution
distribution = distribution.view(2, 3, 4).permute(1, 0, 2).detach()
# Works seamlessly across TensorDict, TensorDataClass, and TensorDistribution
data = data.to('cuda').reshape(batch_size, -1).clone()
🔄 Drop-in Compatibility with PyTorch
TensorContainer integrates seamlessly with existing PyTorch workflows:
- torch.distributions compatibility: TensorDistribution is API-compatible with torch.distributions while adding tensor-like operations
- PyTree support: All containers work with torch.utils._pytree operations and torch.compile
- Zero learning curve: If you know PyTorch tensors, you already know TensorContainer
⚡ Eliminates Boilerplate Code
Compare the complexity difference:
With torch.distributions (manual parameter handling):
# Requires type-specific parameter extraction and reconstruction
if isinstance(dist, Normal):
    detached = Normal(loc=dist.loc.detach(), scale=dist.scale.detach())
elif isinstance(dist, Categorical):
    detached = Categorical(logits=dist.logits.detach())
# ... more type checks needed
With TensorDistribution (unified interface):
# Works for any distribution type
detached = dist.detach()
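One way a unified interface of this kind can avoid per-type branches is to map a function over every parameter field of an object and rebuild the same type. The sketch below is a torch-free illustration of that idea using standard-library dataclasses. `NormalParams`, `CategoricalParams`, `Detached`, and `map_fields` are hypothetical names, not part of TensorContainer, and plain Python values stand in for tensors.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Callable


@dataclass(frozen=True)
class Detached:
    # Hypothetical marker showing that a value went through the mapped function.
    value: Any


@dataclass(frozen=True)
class NormalParams:
    # Hypothetical stand-in for the parameters of a normal distribution.
    loc: Any
    scale: Any


@dataclass(frozen=True)
class CategoricalParams:
    # Hypothetical stand-in; a tuple plays the role of a logits tensor.
    logits: Any


def map_fields(obj: Any, fn: Callable[[Any], Any]) -> Any:
    """Apply fn to every dataclass field and rebuild an object of the same type."""
    updated = {f.name: fn(getattr(obj, f.name)) for f in fields(obj)}
    return replace(obj, **updated)


# The same call handles both types; no isinstance checks are needed.
normal = map_fields(NormalParams(loc=1.0, scale=0.5), Detached)
categorical = map_fields(CategoricalParams(logits=(0.1, 0.2)), Detached)
print(normal)
print(categorical)
```

The caller never names the concrete type: the field list comes from the object itself, so adding a new parameter class requires no change to `map_fields`.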
🏗️ Structured Data Made Simple
Handle complex, nested tensor structures with the same ease as single tensors:
- Batch semantics: Consistent shape handling across all nested tensors
- Device management: Move entire structures between CPU/GPU with single operations
- Shape validation: Automatic verification of tensor compatibility
- Type safety: Full IDE support with static typing and autocomplete
TensorContainer doesn't just store your data—it makes working with structured tensors as intuitive as working with individual tensors, while maintaining full compatibility with the PyTorch ecosystem you already know.
Table of Contents
- Installation
- Quick Start
- Features
- API Overview
- torch.compile Compatibility
- Contributing
- Documentation
- License
- Authors
- Contact and Support
Installation
From Source (Development)
# Clone the repository
git clone https://github.com/mctigger/tensor-container.git
cd tensor-container
# Install in development mode
pip install -e .
# Install with development dependencies
pip install -e ".[dev]"
Requirements
- Python 3.9+
- PyTorch 2.0+
Quick Start
TensorDict: Dictionary-Style Containers
import torch
from tensorcontainer import TensorDict
# Create a TensorDict with batch semantics
data = TensorDict({
    'observations': torch.randn(32, 128),
    'actions': torch.randn(32, 4),
    'rewards': torch.randn(32, 1)
}, shape=(32,), device='cpu')
# Dictionary-like access
obs = data['observations']
data['new_field'] = torch.zeros(32, 10)
# Batch operations work seamlessly
stacked_data = torch.stack([data, data]) # Shape: (2, 32)
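The shape comment above follows the usual `torch.stack` rule: a new dimension of size N is inserted at the stacking position. The helper below is a small torch-free sketch of that arithmetic. `stacked_shape` is a hypothetical name, and the assumption that every field of the container is stacked along the same dimension is mine, not stated by the source.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def stacked_shape(shape: Shape, count: int, dim: int = 0) -> Shape:
    """Shape obtained by stacking `count` tensors of identical `shape` along `dim`."""
    if not 0 <= dim <= len(shape):
        raise ValueError(f"dim {dim} is out of range for shape {shape}")
    # A new axis of size `count` is inserted at position `dim`.
    return shape[:dim] + (count,) + shape[dim:]


# Batch shape of the container and full shape of one of its fields.
print(stacked_shape((32,), 2))
print(stacked_shape((32, 128), 2))
```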
TensorDataClass: Type-Safe Containers
import torch
from tensorcontainer import TensorDataClass
class RLData(TensorDataClass):
    observations: torch.Tensor
    actions: torch.Tensor
    rewards: torch.Tensor

# Create with full type safety and IDE support
data = RLData(
    observations=torch.randn(32, 128),
    actions=torch.randn(32, 4),
    rewards=torch.randn(32, 1),
    shape=(32,),
    device='cpu'
)
# Type-safe field access with autocomplete
obs = data.observations
data.actions = torch.randn(32, 8) # Type-checked assignment
TensorDistribution: Probabilistic Containers
import torch
from tensorcontainer import TensorDistribution
# Built-in distribution types
from tensorcontainer.tensor_distribution import (
    TensorNormal, TensorBernoulli, TensorCategorical,
    TensorTruncatedNormal, TensorTanhNormal
)

# Create probabilistic tensor containers
normal_dist = TensorNormal(
    loc=torch.zeros(32, 4),
    scale=torch.ones(32, 4),
    shape=(32,),
    device='cpu'
)

# Sample and compute probabilities
samples = normal_dist.sample()  # Shape: (32, 4)
log_probs = normal_dist.log_prob(samples)
entropy = normal_dist.entropy()

# Categorical distributions for discrete actions
categorical = TensorCategorical(
    logits=torch.randn(32, 6),  # 6 possible actions
    shape=(32,),
    device='cpu'
)
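For reference, the quantities behind `log_prob` and `entropy` have simple closed forms in the scalar case. The sketch below writes them out in plain Python for a univariate normal, along with the log-sum-exp normalization that turns logits into log-probabilities for a categorical distribution. The function names are hypothetical, and how the library reduces these values over event dimensions is not covered here.

```python
import math
from typing import List, Sequence


def normal_log_prob(x: float, loc: float, scale: float) -> float:
    """Log-density of a univariate normal distribution evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def normal_entropy(scale: float) -> float:
    """Differential entropy of a univariate normal: 0.5 * log(2 * pi * e * scale^2)."""
    return 0.5 + 0.5 * math.log(2.0 * math.pi) + math.log(scale)


def categorical_log_probs(logits: Sequence[float]) -> List[float]:
    """Normalize logits into log-probabilities using a stable log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


print(normal_log_prob(0.0, loc=0.0, scale=1.0))
print(normal_entropy(1.0))
print(categorical_log_probs([1.0, 2.0, 3.0]))
```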
PyTree Operations
# All containers work seamlessly with PyTree operations
import torch.utils._pytree as pytree
# Transform all tensors in the container
doubled_data = pytree.tree_map(lambda x: x * 2, data)
# Combine multiple containers
combined = pytree.tree_map(lambda x, y: x + y, data1, data2)
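Conceptually, a PyTree operation flattens a container into a list of leaves plus a description of its structure, transforms the leaves, and rebuilds the container. The sketch below is a pure-Python stand-in that handles only nested dicts; it is not the `torch.utils._pytree` implementation, and the functions are simplified look-alikes written for illustration.

```python
from typing import Any, Callable, List, Tuple


def tree_flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Split a nested dict into a flat list of leaves and a structure spec."""
    if isinstance(tree, dict):
        leaves: List[Any] = []
        spec = {}
        for key, child in tree.items():
            child_leaves, child_spec = tree_flatten(child)
            leaves.extend(child_leaves)
            spec[key] = child_spec
        return leaves, spec
    # Anything that is not a dict is a leaf; None marks its position in the spec.
    return [tree], None


def tree_unflatten(leaves: List[Any], spec: Any) -> Any:
    """Rebuild a nested dict from leaves, consuming them in flatten order."""
    remaining = iter(leaves)

    def build(node: Any) -> Any:
        if node is None:
            return next(remaining)
        return {key: build(child) for key, child in node.items()}

    return build(spec)


def tree_map(fn: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    """Apply fn leaf-wise across one or more trees with the same structure."""
    leaves, spec = tree_flatten(tree)
    other_leaves = [tree_flatten(other)[0] for other in rest]
    return tree_unflatten([fn(*args) for args in zip(leaves, *other_leaves)], spec)


data = {"observations": 1.0, "nested": {"actions": 2.0, "rewards": 3.0}}
print(tree_map(lambda x: x * 2, data))
print(tree_map(lambda x, y: x + y, data, data))
```

Registering a container as a PyTree node amounts to supplying its own flatten and unflatten pair, which is what lets generic functions such as `tree_map` work on it.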
Features
- torch.compile Optimized: Compatible with PyTorch's JIT compiler
- PyTree Support: Integration with torch.utils._pytree for tree operations
- Zero-Copy Operations: Efficient tensor sharing and manipulation
- Type Safety: Static typing support with IDE autocomplete and type checking
- Batch Semantics: Consistent batch/event dimension handling
- Shape Validation: Automatic validation of tensor shapes and device consistency
- Multiple Container Types: Different container types for different use cases
- Probabilistic Support: Distribution containers for probabilistic modeling
- Comprehensive Testing: Extensive test suite with compile compatibility verification
- Memory Efficient: Optimized memory usage with slots-based dataclasses
API Overview
Core Components
- TensorContainer: Base class providing core tensor manipulation operations with batch/event dimension semantics
- TensorDict: Dictionary-like container for dynamic tensor collections with nested structure support
- TensorDataClass: DataClass-based container for static, typed tensor structures
- TensorDistribution: Distribution wrapper for probabilistic tensor operations
Key Concepts
- Batch Dimensions: Leading dimensions defined by the shape parameter, consistent across all tensors
- Event Dimensions: Trailing dimensions beyond batch shape, can vary per tensor
- PyTree Integration: All containers are registered PyTree nodes for seamless tree operations
- Device Consistency: Automatic validation ensures all tensors reside on compatible devices
- Unsafe Construction: Context manager for performance-critical scenarios with validation bypass
torch.compile Compatibility
Tensor Container is designed for torch.compile compatibility:
@torch.compile
def process_batch(data: TensorDict) -> TensorDict:
    # PyTree operations compile efficiently
    return TensorContainer._tree_map(lambda x: torch.relu(x), data)

@torch.compile
def sample_and_score(dist: TensorNormal, actions: torch.Tensor) -> torch.Tensor:
    # Distribution operations are compile-safe
    return dist.log_prob(actions)

# All operations compile efficiently with minimal graph breaks
compiled_result = process_batch(tensor_dict)
log_probs = sample_and_score(normal_dist, action_tensor)
The testing framework includes compile compatibility verification to ensure operations work efficiently under JIT compilation, including:
- Graph break detection and minimization
- Recompilation tracking
- Memory leak prevention
- Performance benchmarking
Contributing
Contributions are welcome! Tensor Container is a learning project for exploring PyTorch internals and tensor container implementations.
Development Setup
# Clone and install in development mode
git clone https://github.com/mctigger/tensor-container.git
cd tensor-container
pip install -e ".[dev]"
Running Tests
# Run all tests with coverage
pytest --strict-markers --cov=src
# Run specific test modules
pytest tests/tensor_dict/test_compile.py
pytest tests/tensor_dataclass/
pytest tests/tensor_distribution/
# Run compile-specific tests
pytest tests/tensor_dict/test_graph_breaks.py
pytest tests/tensor_dict/test_recompilations.py
Development Guidelines
- All new features must maintain torch.compile compatibility
- Comprehensive tests required, including compile compatibility verification
- Follow existing code patterns and typing conventions
- Distribution implementations must support KL divergence registration
- Memory efficiency considerations for large-scale tensor operations
- Unsafe construction patterns for performance-critical paths
Contribution Process
- Fork the repository
- Create a feature branch (git checkout -b feature/amazing-feature)
- Make your changes with appropriate tests
- Ensure all tests pass and maintain coverage
- Submit a pull request with a clear description
Documentation
The project includes documentation:
- docs/compatibility.md: Python version compatibility guide and best practices
- docs/testing.md: Testing philosophy, standards, and guidelines
- Source Code Documentation: Extensive docstrings and type annotations throughout the codebase
- Test Coverage: 643+ tests covering all major functionality with 86% code coverage
License
This project is licensed under the MIT License - see the LICENSE file for details.
Authors
- Tim Joseph - mctigger
Contact and Support
- Issues: Report bugs and request features on GitHub Issues
- Discussions: Join conversations on GitHub Discussions
- Email: For direct inquiries, contact tim@mctigger.com
Tensor Container is an academic research project for learning PyTorch internals and tensor container patterns. For production applications, we strongly recommend using the official torch/tensordict library, which is actively maintained by the PyTorch team.