torcheck
Torcheck is a machine learning sanity check toolkit for PyTorch.
About
Torcheck is inspired by Chase Roberts' Medium post. Its major benefit is that you no longer need to write additional testing code for your model training. Just add a few lines of code specifying the checks before training starts; torcheck then takes over and performs the checks simultaneously while the training happens.
Another benefit is that torcheck allows you to check your model on different levels. Instead of checking the whole model, you can specify checks for a submodule, a linear layer, or even the weight tensor! This enables more customization around the sanity checks.
Installation
pip install torcheck
Torcheck in 5 minutes
OK, suppose you have coded up a standard PyTorch training routine like this:
model = Model()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.001,
)

# torcheck code goes here

for epoch in range(num_epochs):
    for x, y in dataloader:
        # calculate loss and backward propagation
By simply adding a few lines of code right before the for loop, you can be more confident about whether your model is training as expected!
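The snippets below reference a model.linear_0 submodule. For concreteness, here is a minimal, hypothetical Model with that layout; the layer names and sizes are illustrative assumptions, not part of torcheck:

import torch
import torch.nn as nn

class Model(nn.Module):
    """Toy classifier used only to make the snippets below concrete."""

    def __init__(self):
        super().__init__()
        # referenced below as model.linear_0 (weight/bias checks)
        self.linear_0 = nn.Linear(16, 8)
        self.linear_1 = nn.Linear(8, 2)

    def forward(self, x):
        return self.linear_1(torch.relu(self.linear_0(x)))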
Step 1: Registering your optimizer(s)
First, register the optimizer(s) with torcheck:
torcheck.register(optimizer)
Step 2: Adding sanity checks
Torcheck enables you to perform a wide range of checks, on both module level and tensor level.
A rule of thumb: use APIs with the add_module prefix when checking anything that subclasses nn.Module, and use APIs with the add_tensor prefix when checking tensors.
Parameters change/not change
You can check whether model parameters actually get updated during the training. Or you can check whether they remain constant if you want them to be frozen.
For our example, some of the possible checks are:
# check all the model parameters will change
# module_name is optional, but it makes error messages more informative when checks fail
torcheck.add_module_changing_check(model, module_name="my_model")

# check the linear layer's parameters won't change
torcheck.add_module_unchanging_check(model.linear_0, module_name="linear_layer_0")

# check the linear layer's weight parameters will change
torcheck.add_tensor_changing_check(
    model.linear_0.weight, tensor_name="linear_0.weight", module_name="my_model"
)

# check the linear layer's bias parameters won't change
torcheck.add_tensor_unchanging_check(
    model.linear_0.bias, tensor_name="linear_0.bias", module_name="my_model"
)
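The unchanging checks pair naturally with actually freezing a layer. Here is a minimal sketch; requires_grad_(False) is standard PyTorch, not part of torcheck, and the layer name is only an illustration:

# freeze the layer so the unchanging check is expected to pass:
# with requires_grad turned off, the optimizer leaves these parameters untouched
for param in model.linear_0.parameters():
    param.requires_grad_(False)

torcheck.add_module_unchanging_check(model.linear_0, module_name="linear_layer_0")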
Output range check
The basic use case is that you can check whether model outputs are all within a range, say (-1, 1).
You can also check that model outputs are not all within a range. This is useful when you want to make sure softmax is behaving correctly: checking that model outputs are not all within (0, 1) confirms that softmax hasn't been applied before the loss calculation.
You can check the final model output or intermediate output of a submodule.
# check model outputs are within (-1, 1)
torcheck.add_module_output_range_check(
    model, output_range=(-1, 1), module_name="my_model"
)

# check outputs from the linear layer are within (-5, 5)
torcheck.add_module_output_range_check(
    model.linear_0, output_range=(-5, 5), module_name="linear_layer_0"
)

# check model outputs are not all within (0, 1)
# aka softmax hasn't been applied before loss calculation
torcheck.add_module_output_range_check(
    model,
    output_range=(0, 1),
    negate_range=True,
    module_name="my_model",
)
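As an illustration of why the negated range check is useful, the hypothetical model below applies softmax inside forward(). nn.CrossEntropyLoss already expects raw logits, so this silently hurts training; every output lands inside (0, 1), and the check above would flag it. The class and layer names are assumptions for this sketch:

import torch
import torch.nn as nn

class BuggyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_0 = nn.Linear(16, 4)

    def forward(self, x):
        logits = self.linear_0(x)
        # bug: softmax applied before the loss; all outputs fall within (0, 1)
        return torch.softmax(logits, dim=-1)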
NaN check
Check whether parameters become NaN during training, or model outputs contain NaN.
# check whether model parameters become NaN or outputs contain NaN
torcheck.add_module_nan_check(model, module_name="my_model")

# check whether linear layer's weight parameters become NaN
torcheck.add_tensor_nan_check(
    model.linear_0.weight, tensor_name="linear_0.weight", module_name="my_model"
)
Inf check
Check whether parameters become infinite (positive or negative infinity) during training, or model outputs contain infinite values.

# check whether model parameters become infinite or outputs contain infinite values
torcheck.add_module_inf_check(model, module_name="my_model")

# check whether linear layer's weight parameters become infinite
torcheck.add_tensor_inf_check(
    model.linear_0.weight, tensor_name="linear_0.weight", module_name="my_model"
)
Adding multiple checks in one call
You can add all checks for a module/tensor in one call:
# add all checks for model together
torcheck.add_module(
    model,
    module_name="my_model",
    changing=True,
    output_range=(-1, 1),
    check_nan=True,
    check_inf=True,
)

# add all checks for linear layer's weight together
torcheck.add_tensor(
    model.linear_0.weight,
    tensor_name="linear_0.weight",
    module_name="my_model",
    changing=True,
    check_nan=True,
    check_inf=True,
)
(Optional) Step 3: Turning off checks
When your model has passed all the checks, you can easily turn them off to get rid of the overhead:
torcheck.disable()
If you want to turn on the checks again, just call
torcheck.enable()
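Putting it all together
Putting the pieces together, a full training script might look like the sketch below. Model, dataloader, num_epochs, and the loss function are placeholders you would supply yourself; the torcheck calls are the ones documented above.

import torch
import torcheck

model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# register the optimizer and attach checks before training starts
torcheck.register(optimizer)
torcheck.add_module_changing_check(model, module_name="my_model")
torcheck.add_module_nan_check(model, module_name="my_model")
torcheck.add_module_inf_check(model, module_name="my_model")

criterion = torch.nn.CrossEntropyLoss()
for epoch in range(num_epochs):
    for x, y in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()  # torcheck performs its checks while training runs

# once everything passes, turn the checks off to remove the overhead
torcheck.disable()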