Mocks PyTorch modules so that tests run faster

torchmocks

Test pytorch code with minimal computational overhead.

Problem

The computational overhead of neural networks discourages thorough testing during development and within CI/CD pipelines.

Solution

Torchmocks replaces common building blocks (such as torch.nn.Conv2d) with replicas that only keep track of tensor shapes and device location. Shape and device are often the only things a test needs to check to confirm that PyTorch code is wired together correctly.
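To illustrate what "only keep track of tensor shapes" can mean, here is a minimal sketch of shape-only inference for a 2D convolution. It is not the torchmocks implementation: the function name conv2d_output_shape and the helper _pair are hypothetical, and the code uses plain Python tuples instead of tensors. The arithmetic follows the output-size formula given in the torch.nn.Conv2d documentation.

```python
def _pair(value):
    """Expand an int into an (h, w) pair; pass 2-tuples through. Hypothetical helper."""
    return (value, value) if isinstance(value, int) else tuple(value)


def conv2d_output_shape(input_shape, out_channels, kernel_size,
                        stride=1, padding=0, dilation=1):
    """Return the (N, C_out, H_out, W_out) shape a Conv2d would produce.

    No convolution is computed; only the shape arithmetic is performed.
    This is an illustrative sketch, not the torchmocks API.
    """
    n, _, height, width = input_shape
    k, s, p, d = (_pair(v) for v in (kernel_size, stride, padding, dilation))

    def out_size(size, i):
        # floor((size + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
        return (size + 2 * p[i] - d[i] * (k[i] - 1) - 1) // s[i] + 1

    return (n, out_channels, out_size(height, 0), out_size(width, 1))
```

For example, a 7x7 convolution with stride 2 and padding 3 (the configuration of the first layer in torchvision's ResNet models) applied to a batch of shape (4, 3, 255, 255) gives conv2d_output_shape((4, 3, 255, 255), 64, 7, stride=2, padding=3) == (4, 64, 128, 128), at the cost of a few integer operations.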

Install

pip install torchmocks

Example

import torch
import torchmocks
from torchvision.models import resnet152

def test_mock_resnet():
    net = resnet152()
    torchmocks.mock(net)
    image_batch = torch.zeros(4, 3, 255, 255)
    output = net(image_batch)
    assert output.shape == (4, 1000)

PyTorch Lightning Users

You can exercise most of your training code with torchmocks and the fast_dev_run option for Trainer. See full example here.

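import pytorch_lightning
import torch
from torchmocks import mock

# MockDataset and ExamplePytorchLightningModule are assumed to be defined
# elsewhere (see the full example).
def test_training():
    dataset = MockDataset()
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=3)
    val_loader = torch.utils.data.DataLoader(dataset, batch_size=3)
    model = ExamplePytorchLightningModule()
    mock(model, debug=True)
    # fast_dev_run=2 limits the run to two training and two validation batches.
    trainer = pytorch_lightning.Trainer(fast_dev_run=2)
    trainer.fit(model, train_loader, val_loader)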

Status

This is a work in progress and only a handful of torch modules have been mocked. Modules that have not been mocked will run their normal computation during the forward pass. I'm also exploring other ways to do shape inference in order to mock operations that don't appear in the module tree. Let me know if you have any ideas.
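As one illustration of shape inference for an operation that never appears in the module tree, consider an elementwise addition such as the skip connection in a residual block. The sketch below infers the result shape from the operand shapes using the standard broadcasting rules. It is only an assumption about how such inference could look, not something torchmocks provides; the name broadcast_shape is hypothetical.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Infer the shape of an elementwise op (e.g. x + y) from operand shapes.

    Dimensions are compared from the trailing end; a dimension of size 1
    stretches to match the other operand. Illustrative sketch only.
    """
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} are not broadcastable")
    return tuple(reversed(result))
```

With this, broadcast_shape((4, 64, 56, 56), (64, 1, 1)) evaluates to (4, 64, 56, 56), while mismatched shapes such as (3,) and (4,) raise a ValueError, which is the kind of wiring mistake a shape-only test is meant to catch.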
