
Mocks PyTorch modules so that tests run faster

Project description

torchmocks

Test PyTorch code with minimal computational overhead.

Problem

The computational overhead of neural networks discourages thorough testing during development and within CI/CD pipelines.

Solution

Torchmocks replaces common building blocks (such as torch.nn.Conv2d) with lightweight replicas that only keep track of tensor shapes and device location, skipping the actual computation. Shape and device are often the only properties a test needs to check to confirm that PyTorch code is wired together correctly.
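To make the idea concrete, here is a minimal sketch of what a shape-only replica of torch.nn.Conv2d could look like. It is not torchmocks' actual implementation: the class name ShapeOnlyConv2d is hypothetical, and the sketch assumes a 4-D batched input and integer (not string) padding. It applies the standard Conv2d output-size formula to each spatial dimension and returns zeros with the input's dtype and device, so no convolution is computed.

```python
import torch
import torch.nn as nn


class ShapeOnlyConv2d(nn.Module):
    """Hypothetical stand-in for nn.Conv2d that skips the convolution.

    Only the output shape, dtype and device are reproduced.
    Assumes integer padding and a batched (N, C, H, W) input.
    """

    def __init__(self, conv: nn.Conv2d):
        super().__init__()
        # nn.Conv2d stores these hyperparameters as 2-tuples
        self.out_channels = conv.out_channels
        self.kernel_size = conv.kernel_size
        self.stride = conv.stride
        self.padding = conv.padding
        self.dilation = conv.dilation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, _, h, w = x.shape
        out_hw = []
        for size, k, s, p, d in zip(
            (h, w), self.kernel_size, self.stride, self.padding, self.dilation
        ):
            # Standard Conv2d output-size formula for one spatial dimension
            out_hw.append((size + 2 * p - d * (k - 1) - 1) // s + 1)
        # new_zeros keeps the input's dtype and device
        return x.new_zeros((n, self.out_channels, *out_hw))


conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
mock = ShapeOnlyConv2d(conv)
print(mock(torch.zeros(4, 3, 255, 255)).shape)  # torch.Size([4, 64, 128, 128])
```

Because the output is created with new_zeros, it is not connected to the autograd graph, so this sketch only covers the forward pass.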

Example

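```python
import torch
import torchmocks
from torchvision.models import resnet152

def test_mock_resnet():
    net = resnet152()
    # Replace supported submodules (e.g. Conv2d) in place with shape-only mocks
    torchmocks.mock(net)
    image_batch = torch.zeros(4, 3, 255, 255)
    output = net(image_batch)
    # Only the output shape is checked; mocked layers do no real computation
    assert output.shape == (4, 1000)
```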

Status

This is a work in progress, and only a handful of torch modules have been mocked so far. Modules that have not been mocked run their normal computation during the forward and backward passes.
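The behaviour described above, where only supported modules are replaced and everything else runs normally, can be sketched as a recursive walk over a model's children. This is a hypothetical illustration rather than torchmocks' own mock() function: swap_modules and ShapeOnlyLinear are names invented here, and matching is by exact type, so subclasses of a mocked module are left untouched.

```python
from typing import Callable, Dict, Type

import torch
import torch.nn as nn


class ShapeOnlyLinear(nn.Module):
    """Hypothetical shape-only replacement for nn.Linear."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.out_features = linear.out_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same leading dims as the input, last dim becomes out_features
        return x.new_zeros((*x.shape[:-1], self.out_features))


def swap_modules(
    module: nn.Module,
    factories: Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]],
) -> None:
    """Replace matching children in place; recurse into everything else."""
    for name, child in list(module.named_children()):
        factory = factories.get(type(child))
        if factory is not None:
            # Module.__setattr__ re-registers the new submodule under the same name
            setattr(module, name, factory(child))
        else:
            # Unmatched modules keep their real computation
            swap_modules(child, factories)


net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Sequential(nn.Linear(16, 4)))
swap_modules(net, {nn.Linear: ShapeOnlyLinear})
print(net(torch.zeros(2, 8)).shape)  # torch.Size([2, 4])
```

Iterating over a list copy of named_children() keeps the walk independent of the in-place replacements.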



Download files


Source Distribution

torchmocks-0.0.1.tar.gz (3.0 kB)

Built Distribution

torchmocks-0.0.1-py3-none-any.whl (3.7 kB, Python 3)
