# Torch Testing

A collection of assertion methods to compare PyTorch Tensors in tests.

All assertion methods currently convert the tensors to `numpy` arrays and pass them to the appropriate `numpy.testing` function. As a result, a failing assertion reports detailed information about why the comparison failed.

Last tested with **Python 3.6.4 :: Anaconda, Inc.** and **PyTorch 0.4**.

## Installation

You can install this package using `pip`:

```sh
pip install torch_testing
```

## Usage example

You can assert that two `torch.Tensor` objects are equal as follows:

```py
import unittest
import torch
import torch_testing as tt


class TestSomeClass(unittest.TestCase):

    def test_some_method(self):
        a = torch.tensor([1, 2])
        b = torch.tensor([1, 2])
        tt.assert_equal(a, b)


if __name__ == '__main__':
    unittest.main()
```

## Assertion methods

### `assert_equal(actual, expected, **kwargs)`
Currently converts both tensors to `numpy` arrays using `tensor.numpy()` and passes them to [numpy.testing.assert_equal](https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_equal.html#numpy.testing.assert_equal).
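
The conversion described above can be sketched roughly as follows. This is a minimal illustration, not the package's source: `assert_equal_sketch` is a hypothetical name, and the sketch assumes CPU tensors that do not require gradients, since `tensor.numpy()` only works for those.

```python
import numpy as np
import torch


def assert_equal_sketch(actual, expected, **kwargs):
    """Hypothetical stand-in: compare two tensors via numpy.testing."""
    # tensor.numpy() requires CPU tensors that do not require grad.
    np.testing.assert_equal(actual.numpy(), expected.numpy(), **kwargs)


a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 4])

# Equal tensors pass silently.
assert_equal_sketch(a, a.clone())

# Unequal tensors raise an AssertionError whose message comes from numpy
# and describes how the two arrays differ.
try:
    assert_equal_sketch(a, b)
    failure_message = None
except AssertionError as err:
    failure_message = str(err)
```

Printing `failure_message` shows the report that would appear in the test output on failure.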

### `assert_allclose(actual, expected, rtol=1e-07, atol=0, equal_nan=True, **kwargs)`
Currently converts both tensors to `numpy` arrays using `tensor.numpy()` and passes them to [numpy.testing.assert_allclose](https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_allclose.html#numpy.testing.assert_allclose).
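
To illustrate how the tolerance parameters affect the comparison, here is a small sketch that calls `numpy.testing.assert_allclose` directly with the same defaults as the signature above. The names `assert_allclose_sketch` and `fails` are hypothetical helpers for this example, not part of the package.

```python
import numpy as np
import torch


def assert_allclose_sketch(actual, expected, rtol=1e-07, atol=0,
                           equal_nan=True, **kwargs):
    """Hypothetical stand-in that forwards to numpy.testing.assert_allclose."""
    np.testing.assert_allclose(actual.numpy(), expected.numpy(),
                               rtol=rtol, atol=atol, equal_nan=equal_nan,
                               **kwargs)


# float64 is used so that the 1e-9 difference is actually representable.
actual = torch.tensor([1.0, 2.0, float("nan")], dtype=torch.float64)
expected = torch.tensor([1.0 + 1e-9, 2.0, float("nan")], dtype=torch.float64)


def fails(**kwargs):
    """Return True if the comparison raises an AssertionError."""
    try:
        assert_allclose_sketch(actual, expected, **kwargs)
    except AssertionError:
        return True
    return False


default_fails = fails()               # within rtol=1e-07, NaNs treated as equal
strict_fails = fails(rtol=0)          # any difference now counts as a mismatch
nan_fails = fails(equal_nan=False)    # NaN no longer compares equal to NaN
```

With the defaults the comparison passes, while setting `rtol=0` or `equal_nan=False` makes the same pair of tensors fail.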

### `assert_within(tensor, min_val, max_val, rtol=0)`
Ensures that all values of the given `tensor` are greater than or equal to `min_val` and less than or equal to `max_val`. An optional relative tolerance `rtol` can be specified, which behaves as in [numpy.testing.assert_allclose](https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_allclose.html#numpy.testing.assert_allclose).

*NOTE: Uses `assert_allclose` under the hood, so the failure message may currently be somewhat confusing.*
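
The sketch below shows one plausible reading of the range check. It is an assumption for illustration only: `assert_within_sketch` is a hypothetical name, it does not go through `assert_allclose` as the package does, and widening each bound by `rtol * |bound|` is a guess at what "behaves as in `assert_allclose`" means here.

```python
import torch


def assert_within_sketch(tensor, min_val, max_val, rtol=0):
    """Hypothetical range check; not the package's actual implementation."""
    values = tensor.numpy()
    # Assumption: widen each bound by rtol * |bound|, mirroring how
    # assert_allclose scales its tolerance by the expected value.
    lower = min_val - rtol * abs(min_val)
    upper = max_val + rtol * abs(max_val)
    out_of_range = (values < lower) | (values > upper)
    if out_of_range.any():
        raise AssertionError(
            f"{int(out_of_range.sum())} of {values.size} values lie outside "
            f"[{min_val}, {max_val}] (rtol={rtol}): {values[out_of_range]}"
        )


# All values lie inside the closed interval [0, 1], so this passes.
probs = torch.tensor([0.0, 0.25, 1.0])
assert_within_sketch(probs, 0.0, 1.0)

# 1.0005 exceeds max_val, but is accepted once a relative tolerance is allowed.
slightly_over = torch.tensor([0.5, 1.0005], dtype=torch.float64)
assert_within_sketch(slightly_over, 0.0, 1.0, rtol=1e-3)
```

Calling `assert_within_sketch(slightly_over, 0.0, 1.0)` without a tolerance raises an `AssertionError` that lists the offending values.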

## Development

*Unless noted otherwise, all commands are expected to be executed from the root directory of this repository.*

### Building the package for local development

To make the package available locally while making sure changes to the files are reflected immediately, run

```sh
pip install -e .
```

### Test suite

Run all tests using

```sh
python -m unittest discover tests
```

