
Torch Module Cache

A PyTorch module caching decorator that caches module outputs efficiently, with support for both single-input inference and batch processing.

Features

  • Cache PyTorch module outputs to disk or memory
  • Support for both single inputs and batched inputs
  • Automatic smart batching for performance optimization
  • Safe loading options for improved security
  • Memory cache for ultra-fast repeated access
  • Configurable cache paths and naming
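
Conceptually, the decorator adds a memoization layer keyed by a caller-supplied `cache_key`. A minimal pure-Python sketch of that idea (illustrative only — `cache_forward` is a hypothetical name, not the package's implementation):

```python
# Sketch of key-based output caching; not the package's actual code.
def cache_forward(forward):
    memory_cache = {}

    def wrapper(x, cache_key=None):
        if cache_key is None:
            return forward(x)                     # no key: always compute
        if cache_key not in memory_cache:
            memory_cache[cache_key] = forward(x)  # miss: compute and store
        return memory_cache[cache_key]            # hit: return stored output

    return wrapper

@cache_forward
def double(x):
    return x * 2
```

The first call with a given key computes and stores the result; later calls with the same key return the stored value regardless of the input, which is why keys must uniquely identify their inputs.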

Installation

# Option 1: install from PyPI (recommended)
pip install torch-module-cache

# Option 2: install from source
git clone https://github.com/yourusername/torch-module-cache.git
cd torch-module-cache
pip install -e .

Basic Usage

Simple Example

import torch
import torch.nn as nn
from torch_module_cache import cache_module

@cache_module()
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Initialize your model here
        self.linear = nn.Linear(10, 5)
        
    def forward(self, x):
        # The cache_key parameter is injected by the decorator
        # When provided, results will be cached
        print("Cache miss, running forward for input of shape", x.shape)
        return self.linear(x)

# Create model instance
model = MyModel()

print("Normal forward pass (no caching)")
input_tensor = torch.randn(1, 10)
output = model(input_tensor)

print("Cached forward pass (first time will compute and cache)")
output_cached = model(input_tensor, cache_key="my_unique_key")

print("Subsequent calls with the same key will load from cache")
output_from_cache = model(input_tensor, cache_key="my_unique_key")

Batch Processing

The decorator supports batched inference, which can significantly improve performance when processing multiple inputs:

import torch
import torch.nn as nn
from torch_module_cache import cache_module

@cache_module()
class BatchModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(10, 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )
        
    def forward(self, x):
        print("Cache miss, running forward for input of shape", x.shape)
        return self.encoder(x)

# Create model instance
model = BatchModel()

print("Begin test")

# Create a batch of inputs
batch_size = 4
batch_input = torch.randn(batch_size, 10)

# Create a list of cache keys (one for each item in the batch)
batch_keys = ["item1", "item2", "item3", "item4"]

print("Process the entire batch with unique keys for each item")
# The decorator will handle caching each result individually
batch_output = model(batch_input, cache_key=batch_keys)

print("The next time you use the same keys, results will be loaded from cache")
cached_batch_output = model(batch_input, cache_key=batch_keys)

Partial Cache Hits

One of the key features is the ability to handle partial cache hits efficiently:

# Some keys are already cached, some are new
mixed_keys = ["item1", "item2", "new_item1", "new_item2"]

# Only the new items will be processed, cached items will be loaded from cache
mixed_output = model(batch_input, cache_key=mixed_keys)
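
Under the hood, a partial hit amounts to splitting the batch by key: load the hits, run the module only on the misses, then merge results back in input order. A pure-Python sketch of that logic (`batched_lookup` is a hypothetical helper, not the package's API):

```python
def batched_lookup(inputs, keys, cache, compute):
    """Return one output per input, computing only the cache misses."""
    miss_idx = [i for i, k in enumerate(keys) if k not in cache]
    if miss_idx:
        # Run the expensive computation once, on the misses only.
        miss_out = compute([inputs[i] for i in miss_idx])
        for i, out in zip(miss_idx, miss_out):
            cache[keys[i]] = out
    # Reassemble results in the original order from the (now complete) cache.
    return [cache[k] for k in keys]

cache = {"item1": 10}  # "item1" was cached on a previous run
result = batched_lookup([1, 2, 3], ["item1", "item2", "item3"], cache,
                        lambda xs: [x * 10 for x in xs])
```

Only the two missing items reach the compute function, which is where the speedup on partially cached batches comes from.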

Configuration Options

The @cache_module() decorator accepts several configuration parameters:

from torch_module_cache import cache_module, CacheLevel

@cache_module(
    # Path to store cache files (default: ~/.cache/torch-module-cache)
    cache_path="/path/to/cache",
    
    # Subfolder name for this specific model (default: class name)
    cache_name="my_model_cache",
    
    # Cache level: CacheLevel.DISK or CacheLevel.MEMORY
    cache_level=CacheLevel.MEMORY,
    
    # Whether to use safer loading options (recommended for untrusted data)
    safe_load=True
)
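
With disk caching, each key plausibly maps to a file under `cache_path/cache_name`. One safe way to turn an arbitrary key into a filename is to hash it; the sketch below illustrates that layout (an assumption about the scheme, not the package's documented on-disk format):

```python
import hashlib
from pathlib import Path

def cache_file(cache_path, cache_name, cache_key):
    # Hash the key so arbitrary strings become filesystem-safe names.
    digest = hashlib.sha256(str(cache_key).encode()).hexdigest()
    return Path(cache_path) / cache_name / f"{digest}.pt"

p = cache_file("/tmp/my_cache", "my_model_cache", "my_unique_key")
```

Hashing keeps long or special-character keys valid on every platform, at the cost of making cache files non-human-readable.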

Cache Management

from torch_module_cache import clear_memory_caches, clear_disk_caches

# Clear all in-memory caches
clear_memory_caches()

# Clear all disk caches
clear_disk_caches()

# Clear caches for a specific model
clear_memory_caches(cache_name="my_model_cache")
clear_disk_caches(cache_name="my_model_cache")

Performance Considerations

  • Memory vs. Disk Caching: Memory caching is much faster but limited by available RAM
  • Batch Processing: Processing inputs in batches is typically much faster than individual processing
  • Cache Keys: Choose unique and meaningful cache keys that represent your inputs
  • Cache Path: For large models, ensure the cache path has sufficient disk space
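
The benefit of memory caching is easy to see with a toy benchmark: simulate a slow forward pass and time a first (miss) call against a repeated (hit) call. This is a standalone sketch using a plain dict, not the package's API:

```python
import time

cache = {}

def slow_forward(x, cache_key=None):
    if cache_key is not None and cache_key in cache:
        return cache[cache_key]    # hit: near-instant dict lookup
    time.sleep(0.05)               # stand-in for real model compute
    result = x * 2
    if cache_key is not None:
        cache[cache_key] = result
    return result

t0 = time.perf_counter()
first = slow_forward(5, cache_key="k")   # miss: pays the compute cost
miss_time = time.perf_counter() - t0

t0 = time.perf_counter()
second = slow_forward(5, cache_key="k")  # hit: skips the compute entirely
hit_time = time.perf_counter() - t0
```

The same shape of comparison applies to disk caching, except the hit cost is a file read rather than a dict lookup.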

Advanced Usage

Custom Cache Path

# Custom cache path
@cache_module(cache_path="/tmp/my_model_cache")
class CustomPathModel(nn.Module):
    ...

Batch Processing with Mixed Types

# The decorator handles various input types;
# cache keys can be strings, numbers, or any other hashable type
model_output = model(inputs, cache_key=[1, 2, 3, 4])
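
Accepting mixed key types implies some canonicalization step so that, for example, the integer `1` and the string `"1"` do not collide. A plausible sketch of such a rule (an assumption for illustration, not the package's actual scheme) is to prefix each key with its type before lookup:

```python
def normalize_key(key):
    # Map any hashable key to a stable string identifier;
    # including the type name keeps 1 and "1" distinct.
    return f"{type(key).__name__}:{key!r}"

keys = [1, "1", 2.0]
normalized = [normalize_key(k) for k in keys]
```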

Example Scripts

The package includes several example scripts in the examples/ directory:

  • basic_usage.py: Simple example showing basic caching functionality
  • batch_usage.py: Demonstrates batch processing and performance comparison

Notes

  • The first forward pass with a specific cache key will always execute the model
  • For best performance with batches, try to reuse the same batch size and structure
  • Non-tensor inputs and outputs are supported but may have serialization limitations

License

MIT License
