# Torch Module Cache

A PyTorch module caching decorator that enables efficient caching of module outputs, with support for both single-input inference and batch processing.

## Features

- Cache PyTorch module outputs to disk or in memory
- Support for both single and batched inputs
- Automatic smart batching for performance optimization
- Safe loading options for improved security
- In-memory cache for ultra-fast repeated access
- Configurable cache paths and naming
## Installation

```bash
# Clone the repository
git clone https://github.com/yourusername/torch-module-cache.git
cd torch-module-cache

# Install the package
pip install -e .
```
## Basic Usage

### Simple Example

```python
import torch
import torch.nn as nn

from torch_module_cache import cache_module

@cache_module()
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Initialize your model here
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        # The cache_key parameter is injected by the decorator;
        # when it is provided, results are cached
        print("Cache miss, forwarding", x.shape)
        return self.linear(x)

# Create a model instance
model = MyModel()

# Normal forward pass (no caching)
input_tensor = torch.randn(1, 10)
output = model(input_tensor)

# Cached forward pass (the first call computes and caches the result)
output_cached = model(input_tensor, cache_key="my_unique_key")

# Subsequent calls with the same key load the result from the cache
output_from_cache = model(input_tensor, cache_key="my_unique_key")
```
### Batch Processing

The decorator supports batched inference, which can significantly improve performance when processing multiple inputs:

```python
import torch
import torch.nn as nn

from torch_module_cache import cache_module

@cache_module()
class BatchModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(10, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
        )

    def forward(self, x):
        print("Cache miss, forwarding", x.shape)
        return self.encoder(x)

# Create a model instance
model = BatchModel()

# Create a batch of inputs
batch_size = 4
batch_input = torch.randn(batch_size, 10)

# Create a list of cache keys (one for each item in the batch)
batch_keys = ["item1", "item2", "item3", "item4"]

# Process the entire batch with a unique key for each item;
# the decorator caches each result individually
batch_output = model(batch_input, cache_key=batch_keys)

# The next time the same keys are used, results are loaded from the cache
cached_batch_output = model(batch_input, cache_key=batch_keys)
```
### Partial Cache Hits

One of the key features is the ability to handle partial cache hits efficiently:

```python
# Some keys are already cached, others are new
mixed_keys = ["item1", "item2", "new_item1", "new_item2"]

# Only the new items are processed; cached items are loaded from the cache
mixed_output = model(batch_input, cache_key=mixed_keys)
```
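In this example, only the rows keyed `new_item1` and `new_item2` should trigger a forward pass; given the smart-batching feature listed above, the uncached rows are presumably gathered into a single sub-batch, forwarded once, and merged with the cached results in their original order.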
## Configuration Options

The `@cache_module()` decorator accepts several configuration parameters:

```python
import torch.nn as nn

from torch_module_cache import cache_module, CacheLevel  # CacheLevel import path is assumed

@cache_module(
    # Path to store cache files (default: ~/.cache/torch-module-cache)
    cache_path="/path/to/cache",
    # Subfolder name for this specific model (default: class name)
    cache_name="my_model_cache",
    # Cache level: CacheLevel.DISK or CacheLevel.MEMORY
    cache_level=CacheLevel.MEMORY,
    # Whether to use safer loading options (recommended for untrusted data)
    safe_load=True,
)
class ConfiguredModel(nn.Module):  # illustrative class, added to make the snippet complete
    ...
```
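As the two levels suggest, disk caches persist across process restarts under `cache_path`, while memory caches last only for the lifetime of the current process.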
## Cache Management

```python
from torch_module_cache import clear_memory_caches, clear_disk_caches

# Clear all in-memory caches
clear_memory_caches()

# Clear all disk caches
clear_disk_caches()

# Clear the caches for a specific model
clear_memory_caches(cache_name="my_model_cache")
clear_disk_caches(cache_name="my_model_cache")
```
## Performance Considerations

- **Memory vs. disk caching**: Memory caching is much faster but is limited by available RAM (see the timing sketch below)
- **Batch processing**: Processing inputs in batches is typically much faster than processing them individually
- **Cache keys**: Choose unique, meaningful cache keys that represent your inputs
- **Cache path**: For large models, ensure the cache path has sufficient disk space
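To see what caching buys you on your own workload, a quick measurement is more informative than general guidance. Below is a minimal sketch, assuming the `MyModel` class and the `cache_key` API from the basic example above; it times a cache miss against a cache hit:

```python
import time

import torch

model = MyModel()
x = torch.randn(1, 10)

def timed(fn):
    # Return the wall-clock duration of a single call
    start = time.perf_counter()
    fn()
    return time.perf_counter() - start

miss = timed(lambda: model(x, cache_key="timing_demo"))  # computes and caches
hit = timed(lambda: model(x, cache_key="timing_demo"))   # served from the cache
print(f"cache miss: {miss * 1e3:.2f} ms, cache hit: {hit * 1e3:.2f} ms")
```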
## Advanced Usage

### Custom Cache Path

```python
# Custom cache path
@cache_module(cache_path="/tmp/my_model_cache")
class CustomPathModel(nn.Module):
    ...
```

### Batch Processing with Mixed Types

```python
# The decorator handles various input types;
# cache keys can be strings, numbers, or any hashable type
model_output = model(inputs, cache_key=[1, 2, 3, 4])
```
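Since keys only need to uniquely represent their inputs, one option when inputs carry no natural identifier is to derive keys from the tensor contents themselves. The sketch below is an illustration, not part of the library's API; the `tensor_key` helper is hypothetical:

```python
import hashlib

import torch

def tensor_key(t: torch.Tensor) -> str:
    # Derive a deterministic cache key from the tensor's raw bytes,
    # so identical inputs map to identical keys across runs
    data = t.detach().cpu().contiguous().numpy().tobytes()
    return hashlib.sha256(data).hexdigest()

batch = torch.randn(4, 10)
keys = [tensor_key(row) for row in batch]
output = model(batch, cache_key=keys)
```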
## Example Scripts

The package includes several example scripts in the `examples/` directory:

- `basic_usage.py`: Simple example showing basic caching functionality
- `batch_usage.py`: Demonstrates batch processing and a performance comparison
## Notes

- The first forward pass with a specific cache key will always execute the model
- For best performance with batches, try to reuse the same batch size and structure
- Non-tensor inputs and outputs are supported but may have serialization limitations
## File details

Details for the file `torch_module_cache-0.1.1.tar.gz`.

### File metadata

- Download URL: torch_module_cache-0.1.1.tar.gz
- Upload date:
- Size: 13.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9

### File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `2220e3fdd252387b88be7d5c87da5ec37b447ec6528563bbcca3ec24cfd7de62` |
| MD5 | `290bace7f8ad20b9924ba800a1ca1874` |
| BLAKE2b-256 | `181aacd790ff23230b1191c460dd6e000dcadfab17b449821116a3ac360218c1` |
## File details

Details for the file `torch_module_cache-0.1.1-py3-none-any.whl`.

### File metadata

- Download URL: torch_module_cache-0.1.1-py3-none-any.whl
- Upload date:
- Size: 11.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9

### File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `2c7791dc2566f6c73d5cae366f291a93b9842dc9937e93970e40f37b1f2cf0f0` |
| MD5 | `962219a90aedad3401060fe6bd5d3345` |
| BLAKE2b-256 | `7281574671aae87122727d929d05e7c77fa6f6e35eed61e31813423f1fa5afe5` |