torchcache

Cache PyTorch module outputs on the fly. torchcache can cache module outputs persistently to disk or in memory, and is applied with a simple decorator:
```python
import torch.nn as nn

from torchcache import torchcache

@torchcache()
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # This output will be cached
        return self.linear(x)
```
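The mechanism behind such a decorator can be sketched framework-agnostically: hash the input, return the stored output on a hit, otherwise compute and store. A minimal in-memory sketch in plain Python (an illustration of the idea, not torchcache's actual implementation):

```python
import hashlib

class CachedFn:
    """Memoize a function by hashing its input bytes (toy sketch)."""

    def __init__(self, fn):
        self.fn = fn
        self.cache = {}  # hash key -> stored output
        self.hits = 0

    def __call__(self, data: bytes):
        key = hashlib.blake2b(data, digest_size=16).hexdigest()
        if key in self.cache:  # cache hit: skip the computation
            self.hits += 1
            return self.cache[key]
        out = self.fn(data)  # cache miss: compute and store
        self.cache[key] = out
        return out

double = CachedFn(lambda b: b * 2)
first = double(b"abc")   # computed
second = double(b"abc")  # served from the cache
```

torchcache applies the same hit/miss logic transparently inside the module's forward pass.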
Installation
```bash
pip install torchcache
```
Assumptions

- The module is a subclass of `nn.Module`.
- The module's forward method is called with any number of positional arguments, each a tensor of shape (B, *), where B is the batch size and * is any number of dimensions. All tensors must be on the same device and have the same dtype.
- The forward method returns a single tensor of shape (B, *).
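Since inputs arrive as (B, *) batches, a cache of this kind can key each sample independently, so a previously seen sample can hit the cache even when it appears inside a new batch. A toy sketch of per-sample keying (plain tuples stand in for tensor rows; torchcache's real hashing is an implementation detail not shown here):

```python
import hashlib

def sample_keys(batch):
    """Derive one cache key per sample of a (B, *) batch.

    Samples are tuples of floats standing in for tensor rows;
    a real implementation would hash the tensor contents.
    """
    keys = []
    for sample in batch:
        digest = hashlib.blake2b(repr(sample).encode(), digest_size=8)
        keys.append(digest.hexdigest())
    return keys

batch = [(1.0, 2.0), (3.0, 4.0)]  # B = 2, each sample has 2 features
keys = sample_keys(batch)
```

Keying per sample rather than per batch is what makes the (B, *) shape contract above useful: batching order does not affect cache lookups.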
Use case

Caching the outputs of a pre-trained model backbone for faster training, assuming the backbone is frozen and its outputs are not needed for backpropagation. In the following snippet, the outputs of the backbone are cached, while the outputs of the head are not:
```python
import torch
import torch.nn as nn

from torchcache import torchcache

@torchcache(persistent=True)
class MyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        self.eval()
        self.requires_grad_(False)

    def forward(self, x):
        # This output will be cached to disk
        return self.linear(x)

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = MyBackbone()
        self.head = nn.Linear(10, 10)

    def forward(self, x):
        # This output will be cached
        x = self.backbone(x)
        # This output will not be cached
        x = self.head(x)
        return x

model = MyModel()
```
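Persistent caching, as enabled by `persistent=True`, amounts to spilling the hash-to-output map to disk so cached results survive across runs. A minimal stdlib sketch of the idea (one file per hashed input; unrelated to torchcache's actual storage format):

```python
import hashlib
import pickle
import tempfile
from pathlib import Path

def cached_compute(data: bytes, fn, cache_dir: Path):
    """Disk-backed memoization: one pickle file per hashed input."""
    key = hashlib.blake2b(data, digest_size=16).hexdigest()
    path = cache_dir / f"{key}.pkl"
    if path.exists():  # hit: load the stored result, skip fn entirely
        return pickle.loads(path.read_bytes())
    out = fn(data)  # miss: compute, then persist for later runs
    path.write_bytes(pickle.dumps(out))
    return out

cache_dir = Path(tempfile.mkdtemp())
a = cached_compute(b"x", lambda d: len(d), cache_dir)
b = cached_compute(b"x", lambda d: len(d) + 100, cache_dir)  # hit: second fn never runs
```

With a frozen backbone, this means the expensive forward pass for each training sample is paid once, on the first epoch, and read from disk thereafter.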
Environment variables

The following environment variables control the package's behavior:

- `TORCHCACHE_LOG_LEVEL`: logging level, defaults to `WARN`
- `TORCHCACHE_LOG_FMT`: logging format, defaults to `[torchcache] - %(asctime)s - %(name)s - %(levelname)s - %(message)s`
- `TORCHCACHE_LOG_DATEFMT`: logging date format, defaults to `%Y-%m-%d %H:%M:%S`
- `TORCHCACHE_LOG_FILE`: path to the log file, defaults to `None`; opened in append mode
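For example, the variables can be set before torchcache is imported to redirect verbose logs to a file (the values below are illustrative, not defaults):

```python
import os

# Must be set before torchcache is imported; values here are illustrative.
os.environ["TORCHCACHE_LOG_LEVEL"] = "DEBUG"
os.environ["TORCHCACHE_LOG_FMT"] = "[torchcache] - %(levelname)s - %(message)s"
os.environ["TORCHCACHE_LOG_FILE"] = "torchcache.log"  # opened in append mode
```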
Contribution

- Install Python.
- Install `poetry`.
- Run `poetry install` to install dependencies.
- Run `poetry run pre-commit install` to install pre-commit hooks.
Download files

Source Distribution: torchcache-0.1.0.tar.gz (9.2 kB)

Built Distribution: torchcache-0.1.0-py3-none-any.whl
Hashes for torchcache-0.1.0-py3-none-any.whl

Algorithm | Hash digest
--- | ---
SHA256 | feb972f5ba9f30319a156d361dd0b922aea0aa69f47ea5bf3df7b38c07f1b5d9
MD5 | ac930727f4f7a78873b9f423d00a763f
BLAKE2b-256 | 48c39f7eaba19e1fa9250c3a7ac3026383c2e22dd279d6b7807890ef6643e64d