
Cache PyTorch module outputs on the fly

Project description

torchcache


Effortlessly cache PyTorch module outputs on-the-fly with torchcache.

It is particularly useful for caching and serving the outputs of large, computationally expensive pre-trained PyTorch modules, such as vision transformers. Note that gradients do not flow through cached outputs.

Features

  • Cache PyTorch module outputs either in-memory or persistently to disk.
  • Simple decorator-based interface for easy usage.
  • Uses an MRU (most-recently-used) cache to limit memory/disk usage.
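To illustrate the third feature, here is a minimal, self-contained sketch of an MRU-bounded cache in pure Python. This is only a conceptual illustration of the eviction policy (evict the most recently used entry when full), not torchcache's actual implementation:

```python
from collections import OrderedDict

class MRUCache:
    """Conceptual sketch of an MRU cache: when the cache is full,
    the *most*-recently-used entry is evicted, so older entries survive."""

    def __init__(self, max_items: int):
        self.max_items = max_items
        self._store = OrderedDict()  # keys ordered oldest -> newest use

    def get(self, key):
        if key not in self._store:
            return None
        self._store.move_to_end(key)  # mark as most recently used
        return self._store[key]

    def put(self, key, value):
        if key in self._store:
            self._store.move_to_end(key)
        elif len(self._store) >= self.max_items:
            # MRU eviction: drop the newest entry, keep the older ones
            self._store.popitem(last=True)
        self._store[key] = value
```

MRU eviction suits workloads where older entries (e.g. embeddings of a fixed dataset seen in earlier epochs) are more likely to be reused than the most recent one.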

Installation

pip install torchcache

Basic usage

Quickly cache the output of your PyTorch module with a single decorator:

import torch
from torch import nn

from torchcache import torchcache

@torchcache()
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # This output will be cached
        return self.linear(x)

model = MyModule()
input_tensor = torch.ones(1, 10, dtype=torch.float32)
# Output is cached during the first call...
output = model(input_tensor)
# ...and is retrieved from the cache for the next one
output_cached = model(input_tensor)

See the documentation at torchcache.readthedocs.io for more examples.

Assumptions

To ensure seamless operation, torchcache assumes the following:

  • Your module is a subclass of nn.Module.
  • The module's forward method accepts any number of positional arguments with shapes (B, *), where B is the batch size and * represents any number of dimensions. All tensors should be on the same device and have the same dtype.
  • The forward method returns a single tensor of shape (B, *).
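These batching assumptions make per-sample caching possible: each row along the batch dimension can be hashed independently, so cached and uncached samples within one batch can be handled separately. A minimal sketch of that idea in pure Python (lists stand in for tensor rows here; torchcache uses its own hashing scheme):

```python
import hashlib

def sample_hashes(batch, *more_batches):
    """Hash each sample (row) across one or more batched inputs.
    Conceptual sketch: the i-th hash combines the i-th row of every input,
    so two batches containing the same sample produce the same hash."""
    hashes = []
    for samples in zip(batch, *more_batches):
        h = hashlib.blake2b(digest_size=16)
        for sample in samples:
            h.update(repr(sample).encode())
        hashes.append(h.hexdigest())
    return hashes
```

With per-sample hashes, a forward pass only needs to run on the subset of rows whose hashes are not yet in the cache, and the results can be scattered back into the output batch.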

Contribution

  1. Ensure you have Python installed.
  2. Install poetry.
  3. Run poetry install to set up dependencies.
  4. Run poetry run pre-commit install to install pre-commit hooks.
  5. Create a branch, make your changes, and open a pull request.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchcache-0.4.1.tar.gz (11.1 kB)

Uploaded Source

Built Distribution

torchcache-0.4.1-py3-none-any.whl (10.8 kB)

Uploaded Python 3
