TensorDict-like functionality for PyTorch with PyTree compatibility and torch.compile support

Tensor Container

Tensor containers for PyTorch with PyTree compatibility and torch.compile optimization

Tensor Container provides efficient, type-safe tensor container implementations for PyTorch workflows. It includes PyTree integration and torch.compile optimization for batched tensor operations.

The library includes tensor containers (dict-based and dataclass-based) and distributions that mirror those in torch.distributions.

What is TensorContainer?

TensorContainer transforms how you work with structured tensor data in PyTorch by providing tensor-like operations for entire data structures. Instead of manually managing individual tensors, TensorContainer lets you treat complex data as unified entities that behave just like regular tensors.

Core Benefits

  • Unified Operations: Apply tensor operations like view(), permute(), detach(), and device transfers to entire data structures
  • Drop-in Compatibility: Seamless integration with existing PyTorch workflows and torch.compile
  • Zero Boilerplate: Eliminate manual parameter handling and type-specific operations
  • Type Safety: Full IDE support with static typing and autocomplete
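
import torch
from tensorcontainer import TensorDict

data = TensorDict(
    {"a": torch.rand(24), "b": torch.rand(24)},
    shape=(24,),
    device="cpu",
)

# A single chained call transforms the entire structure
# (the .to('cuda') step requires a CUDA-capable device)
data = data.view(2, 3, 4).permute(1, 0, 2).to('cuda').detach()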
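
For contrast, here is roughly what the same transformation looks like on a plain dict of tensors, using only torch. This is a sketch of the manual bookkeeping the container is meant to replace, not the library's implementation. The names `raw`, `manual`, and `shapes` are illustrative, and the device transfer is left out so the snippet runs on CPU-only machines.

```python
import torch

# Plain dict of tensors sharing a leading dimension of 24
raw = {"a": torch.rand(24), "b": torch.rand(24)}

# Every operation has to be repeated for each entry by hand
manual = {
    key: value.view(2, 3, 4).permute(1, 0, 2).detach()
    for key, value in raw.items()
}

# Collect the resulting shapes for inspection
shapes = {key: tuple(value.shape) for key, value in manual.items()}
print(shapes)
```

Once entries carry different trailing dimensions, a comprehension like this also needs per-tensor shape handling, which is the kind of boilerplate the container is described as removing.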

Key Features

  • ⚡ JIT Compilation: Designed for torch.compile with fullgraph=True, minimizing graph breaks and maximizing performance
  • 📐 Batch/Event Semantics: Clear distinction between batch dimensions (consistent across tensors) and event dimensions (tensor-specific)
  • 🔄 Device Management: Move entire structures between CPU/GPU with single operations and flexible device compatibility
  • 🔒 Type Safety: Full IDE support with static typing and autocomplete
  • 🏗️ Multiple Container Types: Three specialized containers for different use cases:
    • TensorDict for dynamic, dictionary-style data collections
    • TensorDataClass for type-safe, dataclass-based structures
    • TensorDistribution for probabilistic modeling with 40+ probability distributions
  • 🔧 Advanced Operations: Full PyTorch tensor operations support including view(), permute(), stack(), cat(), and more
  • 🎯 Advanced Indexing: Complete PyTorch indexing semantics with boolean masks, tensor indices, and ellipsis support
  • 📊 Shape Validation: Automatic verification of tensor compatibility with detailed error messages
  • 🌳 Nested Structure Support: Nest different TensorContainer types inside one another

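The batch/event distinction and the shape validation listed above can be illustrated with a small pure-Python sketch. It assumes the rule that every tensor's shape must begin with the container's batch shape, and that whatever follows is that tensor's event shape. The helpers `split_batch_event` and `validate` are hypothetical names written for this illustration and are not part of the library.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def split_batch_event(tensor_shape: Shape, batch_shape: Shape) -> Tuple[Shape, Shape]:
    """Split a tensor shape into (batch dims, event dims).

    Raises ValueError if the tensor shape does not start with the batch shape.
    """
    n = len(batch_shape)
    if tuple(tensor_shape[:n]) != tuple(batch_shape):
        raise ValueError(
            f"shape {tuple(tensor_shape)} does not start with "
            f"batch shape {tuple(batch_shape)}"
        )
    return tuple(tensor_shape[:n]), tuple(tensor_shape[n:])


def validate(shapes: Dict[str, Shape], batch_shape: Shape) -> Dict[str, Shape]:
    """Return the event shape of every field, or raise on the first mismatch."""
    events = {}
    for name, shape in shapes.items():
        try:
            _, events[name] = split_batch_event(shape, batch_shape)
        except ValueError as err:
            # Name the offending field so the error is easy to act on
            raise ValueError(f"field '{name}': {err}") from err
    return events


events = validate(
    {"observations": (32, 128), "actions": (32, 4), "rewards": (32, 1)},
    batch_shape=(32,),
)
print(events)
```

Under this assumed rule, a field of shape (16, 3) in a container with batch shape (32,) would be rejected with an error that names the field.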
Installation

Using pip

pip install tensorcontainer

Requirements

  • Python 3.9+
  • PyTorch 2.6+

Quick Start

TensorContainer transforms how you work with structured tensor data. Instead of managing individual tensors, you can treat entire data structures as unified entities that behave like regular tensors.

1. TensorDict: Dynamic Data Collections

Perfect for reinforcement learning data and dynamic collections:

import torch
from tensorcontainer import TensorDict

# Create a container for RL training data
data = TensorDict({
    'observations': torch.randn(32, 128),
    'actions': torch.randn(32, 4),
    'rewards': torch.randn(32, 1)
}, shape=(32,))

# Dictionary-like access with tensor operations
obs = data['observations']
data['advantages'] = torch.randn(32, 1)  # Add new fields dynamically

# Batch operations work seamlessly
batch = torch.stack([data, data])  # Shape: (2, 32)
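
To make the reported shape concrete, the sketch below reproduces the stacking step on a plain dict with torch alone. It assumes that stacking adds a new leading dimension to every entry, which is what the batch shape (2, 32) suggests. `raw`, `stacked`, and `leaf_shapes` are illustrative names, and this is not how the library is implemented.

```python
import torch

# Plain-dict stand-in for the RL batch above (batch shape (32,))
raw = {
    "observations": torch.randn(32, 128),
    "actions": torch.randn(32, 4),
    "rewards": torch.randn(32, 1),
}

# Stacking two copies adds a new leading batch dimension to every entry,
# while the trailing event dimensions are left untouched
stacked = {key: torch.stack([value, value]) for key, value in raw.items()}

leaf_shapes = {key: tuple(value.shape) for key, value in stacked.items()}
print(leaf_shapes)
```

Every entry shares the leading (2, 32) dimensions, and only the last dimension differs per field.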

2. TensorDataClass: Type-Safe Structures

Ideal for model inputs and structured data with static type checking:

import torch
from tensorcontainer import TensorDataClass

class ModelInput(TensorDataClass):
    features: torch.Tensor
    labels: torch.Tensor

# Create with full type safety and IDE support
batch = ModelInput(
    features=torch.randn(32, 64, 784),
    labels=torch.randint(0, 10, (32, 64)),
    shape=(32, 64)
)

# Unified operations on entire structure - reshape all tensors at once
batch = batch.view(2048)

# Type-safe access with autocomplete works on reshaped data too
loss = torch.nn.functional.cross_entropy(batch.features, batch.labels)
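
The view(2048) call above reshapes only the batch dimensions (32, 64), so features would become (2048, 784) and labels (2048,). A minimal sketch of that behavior with a standard-library dataclass and plain torch follows. `PlainInput` and `view_batch` are hypothetical stand-ins written for this illustration, and the library's own mechanism may differ.

```python
import dataclasses

import torch


@dataclasses.dataclass
class PlainInput:  # hypothetical stand-in, not a TensorDataClass
    features: torch.Tensor
    labels: torch.Tensor


def view_batch(obj, batch_ndim: int, *new_batch: int):
    """Reshape the leading batch_ndim dims of every field, keeping event dims."""
    changes = {}
    for field in dataclasses.fields(obj):
        tensor = getattr(obj, field.name)
        event_shape = tensor.shape[batch_ndim:]  # trailing, tensor-specific dims
        changes[field.name] = tensor.view(*new_batch, *event_shape)
    return dataclasses.replace(obj, **changes)


plain = PlainInput(
    features=torch.randn(32, 64, 784),
    labels=torch.randint(0, 10, (32, 64)),
)

# Collapse the two batch dimensions (32, 64) into a single one of size 2048
flat = view_batch(plain, 2, 2048)
print(flat.features.shape, flat.labels.shape)
```

The helper has to be told how many leading dimensions are batch dimensions, whereas the container records this once through its shape argument.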

3. TensorDistribution: Probabilistic Modeling

Streamline probabilistic computations in reinforcement learning and generative models:

import torch
from tensorcontainer.tensor_distribution import TensorNormal

normal = TensorNormal(
    loc=torch.zeros(100, 4),
    scale=torch.ones(100, 4)  
)

# With torch.distributions we need to extract the parameters, detach them
# and create a new Normal distribution. With TensorDistribution we just call
# .detach() on the distribution. We can also apply other tensor operations,
# such as .view()!
detached_normal = normal.detach()
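
For comparison, the manual route described in the comment above looks roughly like this with torch.distributions alone. The snippet uses only the standard `Normal` class. The `requires_grad=True` parameters are an assumption added here so that the effect of detaching is visible.

```python
import torch
from torch.distributions import Normal

# Parameters that track gradients, as they would when produced by a model
loc = torch.zeros(100, 4, requires_grad=True)
scale = torch.ones(100, 4, requires_grad=True)
dist = Normal(loc=loc, scale=scale)

# Manual detach: pull out each parameter, detach it, and rebuild the distribution
detached = Normal(loc=dist.loc.detach(), scale=dist.scale.detach())

# Reparameterized samples from the rebuilt distribution no longer
# propagate gradients back to loc and scale
sample = detached.rsample()
print(sample.requires_grad, dist.rsample().requires_grad)
```

This has to be written separately for each distribution type, because each one exposes different parameter names.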

Documentation

The project includes comprehensive documentation:

License

This project is licensed under the MIT License - see the LICENSE file for details.

Authors

Contact and Support


Tensor Container is an academic research project for learning PyTorch internals and tensor container patterns. For production applications, we strongly recommend using the official pytorch/tensordict library, which is actively maintained by the PyTorch team.
