
torcheck


Torcheck is a machine learning sanity check toolkit for PyTorch.

About

The creation of torcheck is inspired by Chase Roberts' Medium post. Its major benefit is that you no longer need to write additional testing code for your model training. Just add a few lines of code specifying the checks before training; torcheck will then take over and perform the checks simultaneously while the training happens.

Another benefit is that torcheck allows you to check your model on different levels. Instead of checking the whole model, you can specify checks for a submodule, a linear layer, or even the weight tensor! This enables more customization around the sanity checks.

Installation

pip install torcheck

Torcheck in 5 minutes

OK, suppose you have coded up a standard PyTorch training routine like this:

model = Model()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.001,
)

# torcheck code goes here

for epoch in range(num_epochs):
    for x, y in dataloader:
        # calculate loss and backward propagation
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
By simply adding a few lines of code right before the for loop, you can be more confident about whether your model is training as expected!
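The later snippets reference a `model.linear_0` attribute. For concreteness, a minimal model matching that naming might look like this (the architecture and layer names here are illustrative assumptions, not part of torcheck):

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    # Toy two-layer network whose first layer is exposed as `linear_0`,
    # matching the attribute name used in the check examples below.
    def __init__(self):
        super().__init__()
        self.linear_0 = nn.Linear(4, 8)
        self.linear_1 = nn.Linear(8, 2)

    def forward(self, x):
        return self.linear_1(torch.relu(self.linear_0(x)))

model = Model()
out = model(torch.randn(3, 4))
```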

Step 1: Registering your optimizer(s)

First, register the optimizer(s) with torcheck:

torcheck.register(optimizer)

Step 2: Adding sanity checks

Torcheck enables you to perform a wide range of checks, at both the module level and the tensor level.

A rule of thumb: use the APIs with the add_module prefix when checking anything that subclasses nn.Module, and the APIs with the add_tensor prefix when checking tensors.

Parameters change/not change

You can check whether model parameters actually get updated during the training. Or you can check whether they remain constant if you want them to be frozen.

For our example, some of the possible checks are:

# check all the model parameters will change
# module_name is optional, but it makes error messages more informative when checks fail
torcheck.add_module_changing_check(model, module_name="my_model")
# check the linear layer's parameters won't change
torcheck.add_module_unchanging_check(model.linear_0, module_name="linear_layer_0")
# check the linear layer's weight parameters will change
torcheck.add_tensor_changing_check(
    model.linear_0.weight, tensor_name="linear_0.weight", module_name="my_model"
)
# check the linear layer's bias parameters won't change
torcheck.add_tensor_unchanging_check(
    model.linear_0.bias, tensor_name="linear_0.bias", module_name="my_model"
)
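For an unchanging check to pass, the parameters in question actually have to be frozen. One common way to do that in plain PyTorch (an illustration, independent of torcheck) is to exclude them from gradient computation:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
# Freeze the bias so an "unchanging" check on it would pass.
layer.bias.requires_grad = False

# Only hand trainable parameters to the optimizer.
optimizer = torch.optim.SGD(
    (p for p in layer.parameters() if p.requires_grad), lr=0.1
)

before = layer.bias.clone()
loss = layer(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
# The bias is untouched while the weight was updated.
assert torch.equal(layer.bias, before)
```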

Output range check

The basic use case is checking whether model outputs all fall within a range, say (-1, 1).

You can also check that model outputs are not all within a range. This is useful when you want to make sure softmax hasn't been applied prematurely: it lets you check that model outputs are not all within (0, 1).

You can check the final model output or intermediate output of a submodule.

# check model outputs are within (-1, 1)
torcheck.add_module_output_range_check(
    model, output_range=(-1, 1), module_name="my_model"
)
# check outputs from the linear layer are within (-5, 5)
torcheck.add_module_output_range_check(
    model.linear_0, output_range=(-5, 5), module_name="linear_layer_0"
)

# check model outputs are not all within (0, 1)
# aka softmax hasn't been applied before loss calculation
torcheck.add_module_output_range_check(
    model,
    output_range=(0, 1),
    negate_range=True,
    module_name="my_model",
)
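Module-level output checks like these can be built on PyTorch forward hooks. The sketch below illustrates the idea and is not torcheck's actual implementation; the `add_range_hook` helper and its names are assumptions for demonstration:

```python
import torch
import torch.nn as nn

def add_range_hook(module, low, high, negate=False, name="module"):
    # Attach a forward hook that validates outputs against (low, high).
    # negate=False: all outputs must lie inside the range.
    # negate=True: outputs must NOT all lie inside the range.
    def hook(mod, inputs, output):
        inside = bool(((output > low) & (output < high)).all())
        if negate and inside:
            raise RuntimeError(f"{name}: all outputs within ({low}, {high})")
        if not negate and not inside:
            raise RuntimeError(f"{name}: outputs outside ({low}, {high})")
    return module.register_forward_hook(hook)

# Sigmoid squashes everything into (0, 1), so a negated (0, 1)
# range check on this network fails loudly on any forward pass.
layer = nn.Sequential(nn.Linear(4, 2), nn.Sigmoid())
handle = add_range_hook(layer, 0, 1, negate=True, name="sigmoid_net")
```

The hook handle returned by `register_forward_hook` can be kept and later removed with `handle.remove()`, which is one plausible way a `disable()`-style switch can work.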

NaN check

Check whether parameters become NaN during training, or model outputs contain NaN.

# check whether model parameters become NaN or outputs contain NaN
torcheck.add_module_nan_check(model, module_name="my_model")
# check whether linear layer's weight parameters become NaN
torcheck.add_tensor_nan_check(
    model.linear_0.weight, tensor_name="linear_0.weight", module_name="my_model"
)
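At the tensor level, a NaN check boils down to testing `torch.isnan` on the tracked tensor. A minimal sketch of the underlying idea (the `has_nan` helper is illustrative, not torcheck's API):

```python
import torch

def has_nan(t):
    # True if any element of t is NaN.
    return bool(torch.isnan(t).any())

clean = torch.tensor([1.0, 2.0])
dirty = torch.tensor([1.0, float("nan")])
```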

Inf check

Check whether parameters become infinite (positive or negative infinity) during training, or model outputs contain infinite values.

# check whether model parameters become infinite or outputs contain infinite value
torcheck.add_module_inf_check(model, module_name="my_model")
# check whether linear layer's weight parameters become infinite
torcheck.add_tensor_inf_check(
    model.linear_0.weight, tensor_name="linear_0.weight", module_name="my_model"
)
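Similarly, an Inf check reduces to `torch.isinf`, which catches both positive and negative infinity. A minimal sketch (the `has_inf` helper is illustrative, not torcheck's API):

```python
import torch

def has_inf(t):
    # True if any element of t is +inf or -inf.
    return bool(torch.isinf(t).any())

finite = torch.tensor([0.0, 1e30])
blown_up = torch.tensor([1.0, float("inf"), -float("inf")])
```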

Adding multiple checks in one call

You can add all checks for a module/tensor in one call:

# add all checks for model together
torcheck.add_module(
    model,
    module_name="my_model",
    changing=True,
    output_range=(-1, 1),
    check_nan=True,
    check_inf=True,
)
# add all checks for linear layer's weight together
torcheck.add_tensor(
    model.linear_0.weight,
    tensor_name="linear_0.weight",
    module_name="my_model",
    changing=True,
    check_nan=True,
    check_inf=True,
)

(Optional) Step 3: Turning off checks

When your model has passed all the checks, you can easily turn them off to get rid of the overhead:

torcheck.disable()

If you want to turn on the checks again, just call

torcheck.enable()
