A package for caching PyTorch modules

Torch Module Cache

A PyTorch module caching decorator that enables efficient caching of module outputs, with support for both single inference and batch processing.

Features

  • Cache PyTorch module outputs to disk or memory
  • Support for both single inputs and batched inputs
  • Automatic smart batching for performance optimization
  • Safe loading options for improved security
  • Memory cache for ultra-fast repeated access
  • Configurable cache paths and naming

Installation

# 1. Install from PyPI (recommended)
pip install torch-module-cache

# 2. Or install from source
git clone https://github.com/yourusername/torch-module-cache.git
cd torch-module-cache
pip install -e .

Basic Usage

Simple Example

import torch
import torch.nn as nn
from torch_module_cache import cache_module

@cache_module()
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Initialize your model here
        self.linear = nn.Linear(10, 5)
        
    def forward(self, x):
        # cache_key is handled by the decorator at call time:
        # when a cache_key is provided to the call, the result is cached
        # and this forward pass is skipped on subsequent cache hits
        print("Cache miss, forwarding input of shape", x.shape)
        return self.linear(x)

# Create model instance
model = MyModel()

print("Normal forward pass (no caching)")

# Normal forward pass (no caching)
input_tensor = torch.randn(1, 10)
output = model(input_tensor)

print("Cached forward pass (first time will compute and cache)")
# Cached forward pass (first time will compute and cache)
output_cached = model(input_tensor, cache_key="my_unique_key")

print("Subsequent calls with the same key will load from cache")
# Subsequent calls with the same key will load from cache
output_from_cache = model(input_tensor, cache_key="my_unique_key")
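To sanity-check that a cache hit reproduces the original computation (assuming the forward pass is deterministic), you can compare the tensors directly; a minimal check:

# Outputs served from the cache should match the originally computed outputs
assert torch.allclose(output_cached, output_from_cache)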

Batch Processing

The decorator supports batched inference, which can significantly improve performance when processing multiple inputs:

import torch
import torch.nn as nn
from torch_module_cache import cache_module

@cache_module()
class BatchModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(10, 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )
        
    def forward(self, x):
        print("Not hit the cache, forwarding", x.shape)
        return self.encoder(x)

# Create model instance
model = BatchModel()

print("Begin test")

# Create a batch of inputs
batch_size = 4
batch_input = torch.randn(batch_size, 10)

# Create a list of cache keys (one for each item in the batch)
batch_keys = ["item1", "item2", "item3", "item4"]

print("Process the entire batch with unique keys for each item")
# Process the entire batch with unique keys for each item
# The decorator will handle caching each result individually
batch_output = model(batch_input, cache_key=batch_keys)

print("The next time you use the same keys, results will be loaded from cache")
# The next time you use the same keys, results will be loaded from cache
cached_batch_output = model(batch_input, cache_key=batch_keys)

Partial Cache Hits

One of the key features is the ability to handle partial cache hits efficiently:

# Some keys are already cached, some are new
mixed_keys = ["item1", "item2", "new_item1", "new_item2"]

# Only the new items will be processed, cached items will be loaded from cache
mixed_output = model(batch_input, cache_key=mixed_keys)

Configuration Options

The @cache_module() decorator accepts several configuration parameters:

# CacheLevel is assumed here to be importable from the package alongside cache_module
from torch_module_cache import cache_module, CacheLevel

@cache_module(
    # Path to store cache files (default: ~/.cache/torch-module-cache)
    cache_path="/path/to/cache",

    # Subfolder name for this specific model (default: class name)
    cache_name="my_model_cache",

    # Cache level: CacheLevel.DISK or CacheLevel.MEMORY
    cache_level=CacheLevel.MEMORY,

    # Whether to use safer loading options (recommended for untrusted data)
    safe_load=True
)
class ConfiguredModel(nn.Module):
    # ... model definition as in the earlier examples
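One practical difference to keep in mind: with cache_level=CacheLevel.MEMORY, cached outputs are held in process memory and are gone once the Python process exits, whereas a disk cache persists across runs under the configured cache_path.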

Cache Management

from torch_module_cache import clear_memory_caches, clear_disk_caches

# Clear all in-memory caches
clear_memory_caches()

# Clear all disk caches
clear_disk_caches()

# Clear caches for a specific model
clear_memory_caches(cache_name="my_model_cache")
clear_disk_caches(cache_name="my_model_cache")
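Because cached results are looked up by the cache_key you supply, retraining or otherwise modifying a model can leave stale entries behind for previously used keys; clearing that model's caches (as above) before reusing the same keys avoids serving outdated outputs.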

Performance Considerations

  • Memory vs. Disk Caching: Memory caching is much faster but limited by available RAM
  • Batch Processing: Processing inputs in batches is typically much faster than processing them one at a time (see the timing sketch after this list)
  • Cache Keys: Choose unique and meaningful cache keys that represent your inputs
  • Cache Path: For large models, ensure the cache path has sufficient disk space
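To see these trade-offs on your own hardware, you can time a cold call against a warm call. Below is a minimal sketch, assuming the BatchModel instance (model) from the batch example above is available; the timed helper is purely illustrative and not part of the package:

import time

import torch

def timed(fn):
    # Run fn once and return (result, elapsed seconds)
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start

bench_input = torch.randn(4, 10)
bench_keys = ["bench1", "bench2", "bench3", "bench4"]

# Cold call: nothing cached for these keys yet, so the forward pass runs
_, cold = timed(lambda: model(bench_input, cache_key=bench_keys))

# Warm call: same keys, so results are served from the cache
_, warm = timed(lambda: model(bench_input, cache_key=bench_keys))

print(f"cold: {cold:.4f}s  warm: {warm:.4f}s")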

Advanced Usage

Custom Cache Path

# Custom cache path
@cache_module(cache_path="/tmp/my_model_cache")
class CustomPathModel(nn.Module):
    # ...

Batch Processing with Mixed Types

# The decorator handles various input types
# Cache keys can be strings, numbers, or any other hashable type
model_output = model(inputs, cache_key=[1, 2, 3, 4])

Example Scripts

The package includes several example scripts in the examples/ directory:

  • basic_usage.py: Simple example showing basic caching functionality
  • batch_usage.py: Demonstrates batch processing and performance comparison

Notes

  • The first forward pass with a specific cache key will always execute the model
  • For best performance with batches, try to reuse the same batch size and structure
  • Non-tensor inputs and outputs are supported but may have serialization limitations

License

MIT License
