
torchcache


torchcache caches PyTorch module outputs on the fly. Outputs can be cached persistently on disk or ephemerally in memory, and caching is enabled with a simple decorator:

import torch.nn as nn

from torchcache import torchcache

@torchcache()
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # This output will be cached
        return self.linear(x)
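
Once decorated, the module is used like any other nn.Module; repeated calls with the same input are served from the cache. A minimal usage sketch (the shapes here are illustrative):

import torch

module = MyModule()
x = torch.randn(8, 10)  # a batch of 8 inputs

out1 = module(x)  # computed by the linear layer and cached
out2 = module(x)  # same input, served from the cache
assert torch.allclose(out1, out2)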

Installation

pip install torchcache

Assumptions

  • The module is a subclass of nn.Module.
  • The module's forward method is called with any number of positional arguments, each a tensor of shape (B, *), where B is the batch size and * is any number of dimensions. All tensors must be on the same device and have the same dtype.
  • The forward method returns a single tensor of shape (B, *), as sketched below.
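
To make the shape contract concrete, here is a hedged sketch with two positional arguments (the module below is illustrative, not part of the library):

import torch
import torch.nn as nn
from torchcache import torchcache

@torchcache()
class TwoInputModule(nn.Module):
    def forward(self, x, y):
        # Both x and y are (B, *) tensors on the same device and dtype
        return x + y  # the single (B, *) output that gets cached

module = TwoInputModule()
a = torch.randn(8, 10)  # B=8
b = torch.randn(8, 10)  # same device and dtype as a
out = module(a, b)      # shape (8, 10)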

Use case

Caching the outputs of a pre-trained model backbone for faster training, assuming the backbone is frozen and its outputs are not needed for backpropagation. For example, in the following snippet the outputs of the backbone are cached to disk, while the outputs of the head are not:

import torch
import torch.nn as nn
from torchcache import torchcache


@torchcache(persistent=True)
class MyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
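        # Freeze the backbone so its outputs are deterministic and safe to cache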
        self.eval()
        self.requires_grad_(False)

    def forward(self, x):
        # This output will be cached to disk
        return self.linear(x)

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = MyBackbone()
        self.head = nn.Linear(10, 10)

    def forward(self, x):
        # This output will be cached
        x = self.backbone(x)
        # This output will not be cached
        x = self.head(x)
        return x

model = MyModel()
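
A hedged sketch of how this plays out during training (the optimizer setup and data below are illustrative): on the first pass the backbone outputs are computed and written to the disk cache; later passes over the same inputs read from the cache, while the head trains normally.

optimizer = torch.optim.SGD(model.head.parameters(), lr=1e-2)
x = torch.randn(32, 10)
target = torch.randn(32, 10)

for epoch in range(2):
    out = model(x)  # first epoch: backbone computed and cached; afterwards: cache hit
    loss = nn.functional.mse_loss(out, target)
    optimizer.zero_grad()
    loss.backward()  # gradients flow through the head only
    optimizer.step()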

Environment variables

The following environment variables control the package's logging behavior:

  • TORCHCACHE_LOG_LEVEL - logging level, defaults to WARN
  • TORCHCACHE_LOG_FMT - logging format, defaults to [torchcache] - %(asctime)s - %(name)s - %(levelname)s - %(message)s
  • TORCHCACHE_LOG_DATEFMT - logging date format, defaults to %Y-%m-%d %H:%M:%S
  • TORCHCACHE_LOG_FILE - path to the log file, defaults to None. Opened in append mode.
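
These can also be set from Python before torchcache is imported; a minimal sketch (assuming the variables are read when the package configures its logger):

import os

os.environ["TORCHCACHE_LOG_LEVEL"] = "DEBUG"
os.environ["TORCHCACHE_LOG_FILE"] = "torchcache.log"

from torchcache import torchcache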

Contribution

  1. Install Python.
  2. Install Poetry.
  3. Run poetry install to install dependencies.
  4. Run poetry run pre-commit install to install pre-commit hooks.
