Torch Module Cache

A decorator for caching PyTorch module outputs to disk or memory, with support for both single-input inference and batch processing.

Features

  • Cache PyTorch module outputs to disk or memory
  • Support for both single inputs and batched inputs
  • Automatic smart batching for performance optimization
  • Safe loading options for improved security
  • Memory cache for ultra-fast repeated access
  • Configurable cache paths and naming

Installation

# Clone the repository
git clone https://github.com/yourusername/torch-module-cache.git
cd torch-module-cache

# Install the package
pip install -e .
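
The project also publishes source and wheel distributions on PyPI (0.1.0 at the time of writing), so a direct install should work as well, assuming the published name matches the distribution files:

# Install from PyPI
pip install torch-module-cache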

Basic Usage

Simple Example

import torch
import torch.nn as nn
from torch_module_cache import cache_module

@cache_module()
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Initialize your model here
        self.linear = nn.Linear(10, 5)
        
    def forward(self, x, cache_key=None):
        # The cache_key parameter is injected by the decorator
        # When provided, results will be cached
        return self.linear(x)

# Create model instance
model = MyModel()

# Normal forward pass (no caching)
input_tensor = torch.randn(1, 10)
output = model(input_tensor)

# Cached forward pass (first time will compute and cache)
output_cached = model(input_tensor, cache_key="my_unique_key")

# Subsequent calls with the same key will load from cache
output_from_cache = model(input_tensor, cache_key="my_unique_key")

Batch Processing

The decorator supports batched inference, which can significantly improve performance when processing multiple inputs:

import torch
import torch.nn as nn
from torch_module_cache import cache_module

@cache_module()
class BatchModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(10, 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )
        
    def forward(self, x, cache_key=None):
        return self.encoder(x)

# Create model instance
model = BatchModel()

# Create a batch of inputs
batch_size = 4
batch_input = torch.randn(batch_size, 10)

# Create a list of cache keys (one for each item in the batch)
batch_keys = ["item1", "item2", "item3", "item4"]

# Process the entire batch with unique keys for each item
# The decorator will handle caching each result individually
batch_output = model(batch_input, cache_key=batch_keys)

# The next time you use the same keys, results will be loaded from cache
cached_batch_output = model(batch_input, cache_key=batch_keys)
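
As a rough illustration of why batching helps, the sketch below times a per-item loop against a single batched call. It is not one of the package's examples; it reuses model, batch_input, and batch_size from above and uses fresh keys so that both passes actually execute the model:

import time

# Fresh keys for each pass so nothing is served from an earlier cache entry
loop_keys = [f"loop_item_{i}" for i in range(batch_size)]
fresh_batch_keys = [f"fresh_batch_item_{i}" for i in range(batch_size)]

# One forward call (and one cache write) per item
start = time.perf_counter()
for i, key in enumerate(loop_keys):
    model(batch_input[i:i+1], cache_key=key)
loop_time = time.perf_counter() - start

# A single batched forward call covering all items
start = time.perf_counter()
model(batch_input, cache_key=fresh_batch_keys)
batch_time = time.perf_counter() - start

print(f"per-item loop: {loop_time:.4f}s, single batch: {batch_time:.4f}s")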

Partial Cache Hits

One of the key features is the ability to handle partial cache hits efficiently:

# Some keys are already cached, some are new
mixed_keys = ["item1", "item2", "new_item1", "new_item2"]

# Only the new items will be processed, cached items will be loaded from cache
mixed_output = model(batch_input, cache_key=mixed_keys)
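
Conceptually, the decorator only needs to run the wrapped module for the rows whose keys are missing from the cache, and can fill in the remaining rows from storage. The sketch below illustrates that idea in plain Python; it is not the package's actual implementation, and cache_store is a hypothetical stand-in for whichever disk or memory backend is configured:

import torch

def forward_with_partial_hits(module, batch, keys, cache_store):
    """Illustrative only: run the module just for the uncached rows."""
    outputs = [None] * len(keys)

    # Serve cache hits directly from the store
    for i, key in enumerate(keys):
        if key in cache_store:
            outputs[i] = cache_store[key]

    # Run the module once on the uncached rows only
    miss_idx = [i for i, key in enumerate(keys) if key not in cache_store]
    if miss_idx:
        fresh_rows = module(batch[miss_idx])
        for row, i in zip(fresh_rows, miss_idx):
            outputs[i] = row
            cache_store[keys[i]] = row

    return torch.stack(outputs)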

Configuration Options

The @cache_module() decorator accepts several configuration parameters:

from torch_module_cache import cache_module, CacheLevel

@cache_module(
    # Path to store cache files (default: ~/.cache/torch-module-cache)
    cache_path="/path/to/cache",

    # Subfolder name for this specific model (default: the class name)
    cache_name="my_model_cache",

    # Cache level: CacheLevel.DISK or CacheLevel.MEMORY
    cache_level=CacheLevel.MEMORY,

    # Whether to use safer loading options (recommended for untrusted data)
    safe_load=True
)
class ConfiguredModel(nn.Module):  # placeholder class to attach the decorator to
    # ...
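
As a concrete illustration of these options, the sketch below keeps results in memory only. It assumes CacheLevel is importable from the package alongside the other names shown in this README; memory-level entries live for the lifetime of the process and can be dropped with clear_memory_caches():

import torch
import torch.nn as nn
from torch_module_cache import cache_module, clear_memory_caches, CacheLevel

@cache_module(cache_name="memory_demo", cache_level=CacheLevel.MEMORY)
class MemoryCachedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x, cache_key=None):
        return self.linear(x)

model = MemoryCachedModel()
x = torch.randn(1, 10)

model(x, cache_key="sample")  # computed and stored in memory
model(x, cache_key="sample")  # served from the in-memory cache

clear_memory_caches(cache_name="memory_demo")  # entries dropped; next call recomputes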

Cache Management

from torch_module_cache import clear_memory_caches, clear_disk_caches

# Clear all in-memory caches
clear_memory_caches()

# Clear all disk caches
clear_disk_caches()

# Clear caches for a specific model
clear_memory_caches(cache_name="my_model_cache")
clear_disk_caches(cache_name="my_model_cache")

Performance Considerations

  • Memory vs. Disk Caching: Memory caching is much faster but limited by available RAM
  • Batch Processing: Processing inputs in batches is typically much faster than individual processing
  • Cache Keys: Choose unique and meaningful cache keys that represent your inputs (see the key-derivation sketch after this list)
  • Cache Path: For large models, ensure the cache path has sufficient disk space
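
One simple way to get stable, meaningful keys is to derive them from whatever identifies each input, such as a file path or a content hash. The sketch below is an assumption about your data, not part of the package; the file names and the key_for() helper are hypothetical:

import hashlib

def key_for(path: str) -> str:
    """Hypothetical helper: derive a stable cache key from a file's contents."""
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).hexdigest()
    return f"{path}:{digest[:16]}"

keys = [key_for(p) for p in ["a.wav", "b.wav", "c.wav", "d.wav"]]
output = model(batch_input, cache_key=keys)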

Advanced Usage

Custom Cache Path

# Custom cache path
@cache_module(cache_path="/tmp/my_model_cache")
class CustomPathModel(nn.Module):
    # ...

Batch Processing with Mixed Types

# The decorator handles various input types
# Cache keys can be strings, numbers, or any other hashable type
model_output = model(inputs, cache_key=[1, 2, 3, 4])

Example Scripts

The package includes example scripts in the examples/ directory:

  • basic_usage.py: Simple example showing basic caching functionality
  • batch_usage.py: Demonstrates batch processing and performance comparison

Notes

  • The first forward pass with a specific cache key will always execute the model
  • For best performance with batches, try to reuse the same batch size and structure
  • Non-tensor inputs and outputs are supported but may have serialization limitations

License

MIT License
