
A package for caching PyTorch modules


Torch Module Cache

🚀 One-line code to implement PyTorch feature caching, accelerate training by 30x+!

Torch Module Cache is a simple yet powerful PyTorch tool that enables model feature caching with just one line of code, significantly boosting training and inference speed. Whether it's dataset preprocessing or pretrained model feature caching, it's all made easy.

✨ Key Features

  • 🚀 Minimal Code: Enable caching with just one decorator
  • 📈 Significant Speedup: Real-world tests show 30x+ acceleration per epoch
  • 💻 VRAM Friendly: The model is not loaded until a cache miss occurs, saving VRAM
  • 🔄 Flexible Caching: Support for both dataset and model feature caching
  • 🎯 Smart Inference: Support for inference mode with global cache disabling
  • 💾 Memory Optimized: Automatic cache memory management to prevent leaks

🚀 Quick Start

1. Installation

pip install torch-module-cache

2. Basic Usage

Simply add the @cache_module() decorator to enable feature caching. This is especially effective when a model extracts features with a frozen pretrained backbone:

import torch
import torch.nn as nn

from torch_module_cache import cache_module

# Only need to add one line of code to enable caching
@cache_module()
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 3)
    
    def forward(self, x):
        return self.linear(x)

# Using cache
model = MyModel()
x = torch.randn(1, 10)
# First run will compute and cache the result
output1 = model(x, cache_key="key1")
# Second run will use the cached result
output2 = model(x, cache_key="key1")

# For batch processing, you can use a list of cache keys:
cache_keys = [f"key_{i}" for i in range(10)]
outputs = model(torch.randn(10, 10), cache_key=cache_keys)
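
The list-of-keys form above implies per-sample caching: each row of the batch is stored under its own key, so a later batch can mix hits and misses. Here is a minimal pure-Python sketch of that idea (illustrative only, not the library's actual implementation; all names are hypothetical):

```python
# Illustrative sketch: per-sample keyed memoization, as the batch
# `cache_key` list above implies. `_cache` is a plain in-memory dict.
_cache = {}

def cached_batch_forward(forward, batch, cache_keys):
    """Run `forward` only on rows whose key has not been cached yet."""
    outputs = [None] * len(batch)
    miss_rows, miss_keys = [], []
    for i, key in enumerate(cache_keys):
        if key in _cache:
            outputs[i] = _cache[key]          # cache hit: reuse stored result
        else:
            miss_rows.append(i)
            miss_keys.append(key)
    if miss_rows:
        # One forward pass over the misses only
        computed = forward([batch[i] for i in miss_rows])
        for i, key, out in zip(miss_rows, miss_keys, computed):
            _cache[key] = out
            outputs[i] = out
    return outputs
```

A second call with the same keys computes nothing; a batch with one new key computes only that row.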

3. Pretrained Model Feature Caching

Accelerate your model by caching features from pretrained models like ViT, ResNet, etc.:

import timm
import torch
import torch.nn as nn

from torch_module_cache import cache_module

# Only need to add one line of code to enable caching
@cache_module()
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        # Load pretrained ViT
        self.vit = timm.create_model("vit_base_patch16_224", pretrained=True)
        self.vit.eval()  # Set to eval mode

    def forward(self, x):
        # Extract features from ViT
        with torch.no_grad():
            features = self.vit.forward_features(x)
        return features

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # `feature_extractor` is frozen, so we can use cache to speed up
        self.feature_extractor = FeatureExtractor()
        self.classifier = nn.Linear(768, 10)  # ViT-Base features are 768-dim

    def forward(self, x, cache_key=None):
        # Features will be cached automatically
        features = self.feature_extractor(x, cache_key=cache_key)
        return self.classifier(features)

4. Dataset Feature Caching

Still manually extracting features and saving them to .pt files? Use caching inside your dataset to accelerate data loading with just one line of code:

import torch.nn as nn
from torch.utils.data import Dataset

from torch_module_cache import cache_module

@cache_module(cache_name="feature_processor")
class FeatureProcessor(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 256)
    
    def forward(self, x):
        return self.linear(x)

class CachedDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels
        self.processor = FeatureProcessor()
    
    def __getitem__(self, idx):
        raw_data = self.data[idx]
        # Use the sample index as the cache key; from the second epoch on, results are served from the cache
        processed_data = self.processor(raw_data, cache_key=f"sample_{idx}")
        return processed_data, self.labels[idx]
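
The second-epoch speedup comes from keying the cache by sample index: epoch 1 fills the cache, epoch 2 only looks up. A hedged sketch of that pattern with a plain dict (the library's disk and memory backends are more involved; `process` stands in for the frozen feature extractor):

```python
# Sketch of index-keyed dataset caching: epoch 1 computes, epoch 2 reuses.
class CachingDataset:
    def __init__(self, data, process):
        self.data = data
        self.process = process          # expensive per-sample transform
        self._cache = {}                # "sample_<idx>" -> processed feature

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        key = f"sample_{idx}"
        if key not in self._cache:      # misses happen only on the first epoch
            self._cache[key] = self.process(self.data[idx])
        return self._cache[key]
```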

5. Inference Mode

Disable caching during inference:

from torch_module_cache import enable_inference_mode

# Enable inference mode (disables caching; the model is initialized as soon as the instance is created)
enable_inference_mode()

# Model will compute directly without using cache
model = MyModel()
output = model(x)
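
The global switch can be pictured as a module-level flag consulted before every cache lookup; once set, the wrapper computes directly. A hypothetical sketch of that mechanism, not the package's internals:

```python
# Hypothetical sketch of a global inference-mode flag that bypasses caching.
_INFERENCE_MODE = False
_cache = {}

def enable_inference_mode():
    global _INFERENCE_MODE
    _INFERENCE_MODE = True

def maybe_cached(forward, x, cache_key=None):
    """Return forward(x), consulting the cache only outside inference mode."""
    if _INFERENCE_MODE or cache_key is None:
        return forward(x)               # bypass the cache entirely
    if cache_key not in _cache:
        _cache[cache_key] = forward(x)  # fill on miss
    return _cache[cache_key]
```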

📊 Performance Comparison

Scenario                 Without Cache   With Cache   Speedup
Dataset Preprocessing    100s            3.2s         31.25x
ViT Feature Extraction   2.10s           0.024s       86.82x

📚 More Examples

Check out the examples directory for more usage examples.

⚙️ Configuration Options

The @cache_module() decorator accepts several configuration parameters:

from torch_module_cache import cache_module, CacheLevel

@cache_module(
    # Path to store cache files (default: ~/.cache/torch-module-cache)
    cache_path="/path/to/cache",
    
    # Subfolder name for this specific model (default: class name)
    cache_name="my_model_cache",
    
    # Cache level: CacheLevel.DISK or CacheLevel.MEMORY
    cache_level=CacheLevel.MEMORY,
    
    # Whether to use safer loading options (recommended for untrusted data)
    safe_load=False,
    
    # Maximum memory usage in MB (default: None)
    max_memory_cache_size_mb=None,
)
class MyModel(nn.Module):
    # ... your model implementation
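
The max_memory_cache_size_mb option implies size-bounded eviction. One common way to honor such a cap is least-recently-used (LRU) eviction; here is a sketch built on OrderedDict (the package may use a different policy, and the class name is hypothetical):

```python
from collections import OrderedDict

# Sketch of a size-capped LRU memory cache, one plausible way to
# implement a `max_memory_cache_size_mb`-style limit.
class BoundedCache:
    def __init__(self, max_bytes):
        self.max_bytes = max_bytes
        self.used = 0
        self._store = OrderedDict()     # key -> (value, size_in_bytes)

    def put(self, key, value, size):
        if key in self._store:
            self.used -= self._store.pop(key)[1]
        self._store[key] = (value, size)
        self.used += size
        while self.used > self.max_bytes:          # evict oldest entries
            _, (_, old_size) = self._store.popitem(last=False)
            self.used -= old_size

    def get(self, key):
        if key not in self._store:
            return None
        self._store.move_to_end(key)    # mark as recently used
        return self._store[key][0]
```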

🔧 Cache Management

Memory Management

from torch_module_cache import clear_memory_caches, clear_disk_caches

# Clear all in-memory caches
clear_memory_caches()

# Clear all disk caches
clear_disk_caches()

# Clear caches for a specific model
clear_memory_caches(cache_name="my_model_cache")
clear_disk_caches(cache_name="my_model_cache")
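
Disk caches are, in essence, files under the cache path grouped by cache_name, so clearing them amounts to removing those folders. A hedged sketch of such a helper, assuming a <cache_path>/<cache_name>/ layout like the defaults described above (the layout and function name are assumptions, not the package's code):

```python
import shutil
from pathlib import Path

# Sketch of clearing disk caches, assuming files live under
# <cache_path>/<cache_name>/ (layout is an assumption).
def clear_disk_caches_sketch(cache_path, cache_name=None):
    root = Path(cache_path)
    if not root.exists():
        return
    if cache_name is not None:
        target = root / cache_name      # clear one model's subfolder
        if target.exists():
            shutil.rmtree(target)
    else:
        for sub in root.iterdir():      # clear every model's subfolder
            if sub.is_dir():
                shutil.rmtree(sub)
```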

🤝 Contributing

Issues and Pull Requests are welcome!

📄 License

MIT License
