
# Torch Testing

A collection of assertion methods to compare PyTorch Tensors in tests.

Currently, every assertion method converts the tensors to NumPy arrays and passes them to the corresponding `numpy.testing` function. That way, a failing assertion reports in detail why the test failed.
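The convert-then-delegate approach can be sketched as follows. This is a hypothetical reimplementation for illustration, not the package's source: `to_numpy` and `assert_tensor_equal` are names invented here. The sketch calls `detach().cpu()` before `numpy()` because `Tensor.numpy()` raises an error for tensors that live on a GPU or require gradients; whether `torch_testing` itself does this is an assumption. Lists and NumPy arrays go through `np.asarray`, so the sketch also runs without PyTorch installed.

```python
import numpy as np


def to_numpy(x):
    """Convert a tensor-like object to a NumPy array (hypothetical helper)."""
    # Assumption: anything exposing .detach() and .numpy() is a PyTorch tensor.
    if hasattr(x, "detach") and hasattr(x, "numpy"):
        return x.detach().cpu().numpy()
    # Lists, scalars and NumPy arrays are converted directly.
    return np.asarray(x)


def assert_tensor_equal(actual, expected, **kwargs):
    """Hypothetical wrapper: convert both inputs, then let NumPy compare them."""
    np.testing.assert_equal(to_numpy(actual), to_numpy(expected), **kwargs)


if __name__ == "__main__":
    assert_tensor_equal([1, 2], np.array([1, 2]))  # passes silently
    try:
        assert_tensor_equal([1, 2], [1, 3])
    except AssertionError as err:
        # NumPy's report lists the mismatched elements and both arrays.
        print(err)
```

Keeping the comparison in NumPy is what gives the detailed failure report; the wrapper only has to get the data into an array.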

Last tested with **Python 3.6.4 :: Anaconda, Inc.** and **PyTorch 0.4**.

## Installation

You can install this package using `pip`:

```sh
pip install torch_testing
```

## Usage example

You can assert that two tensors are equal like this:

```py
import unittest

import torch
import torch_testing as tt


class TestSomeClass(unittest.TestCase):

    def test_some_method(self):
        a = torch.tensor([1, 2])
        b = torch.tensor([1, 2])
        # Passes because both tensors hold the same values.
        tt.assert_equal(a, b)


if __name__ == '__main__':
    unittest.main()
```

## Assertion methods

### `assert_equal(actual, expected, **kwargs)`
Currently this assertion method is provided by converting the tensors to `numpy` arrays using `tensor.numpy()` and feeding them to [numpy.testing.assert_equal](https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_equal.html#numpy.testing.assert_equal).
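Because the comparison is delegated to `numpy.testing.assert_equal`, it inherits NumPy's semantics. The NumPy-only sketch below uses plain arrays as stand-ins for `tensor.numpy()` results to illustrate three of them: dtypes are not compared, NaNs in the same positions count as equal, and a shape mismatch fails.

```python
import numpy as np

# Stand-ins for the arrays produced by tensor.numpy().
ints = np.array([1, 2])
floats = np.array([1.0, 2.0])

# Only values are compared, so integer and float arrays with equal values pass.
np.testing.assert_equal(ints, floats)

# NaNs in matching positions are treated as equal.
np.testing.assert_equal(np.array([1.0, np.nan]), np.array([1.0, np.nan]))

# Arrays of different shapes fail.
try:
    np.testing.assert_equal(np.array([1, 2]), np.array([1, 2, 3]))
    shape_mismatch_raised = False
except AssertionError:
    shape_mismatch_raised = True
```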

### `assert_allclose(actual, expected, rtol=1e-07, atol=0, equal_nan=True, **kwargs)`
Currently this assertion method is provided by converting the tensors to `numpy` arrays using `tensor.numpy()` and feeding them to [numpy.testing.assert_allclose](https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_allclose.html#numpy.testing.assert_allclose).
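The `rtol` and `atol` parameters follow NumPy's rule: an element passes when `|actual - expected| <= atol + rtol * |expected|`. The NumPy-only sketch below evaluates that rule element by element (`allclose_mask` is a hypothetical helper that ignores `equal_nan`) to show how strict the default `rtol=1e-07` with `atol=0` is: a relative error of about 1e-6 already fails.

```python
import numpy as np


def allclose_mask(actual, expected, rtol=1e-07, atol=0.0):
    """Elementwise form of the check numpy.testing.assert_allclose performs."""
    actual = np.asarray(actual, dtype=float)
    expected = np.asarray(expected, dtype=float)
    # The tolerance scales with |expected| (NumPy's "desired"), not |actual|.
    return np.abs(actual - expected) <= atol + rtol * np.abs(expected)


actual = np.array([1.0, 100.0])
expected = np.array([1.0 + 1e-8, 100.0 + 1e-4])

# With the defaults, the first element passes and the second fails.
print(allclose_mask(actual, expected))
# Loosening rtol to 1e-5 lets both elements pass.
print(allclose_mask(actual, expected, rtol=1e-5))
```

Since the default `atol` is 0, an expected value of exactly 0 only matches an actual value of exactly 0; pass a small `atol` when comparing results that should be near zero.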

### `assert_within(tensor, min_val, max_val, rtol=0)`
Ensures that all values of the given `tensor` are greater than or equal to `min_val` and less than or equal to `max_val`. An optional relative tolerance `rtol` can be given; it behaves as in [numpy.testing.assert_allclose](https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_allclose.html#numpy.testing.assert_allclose).

*NOTE: Uses `assert_allclose` under the hood, so the failure message may currently be a little confusing.*
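One way a range check can be built on `assert_allclose` is to clip the values to `[min_val, max_val]` and require the clipped copy to match the original. The sketch below assumes that approach; `assert_within_sketch` is a hypothetical name and this is not necessarily how `torch_testing` implements it. It also shows where the confusing message comes from: a failure is reported as a mismatch against the clipped copy rather than as an out-of-range value.

```python
import numpy as np


def assert_within_sketch(values, min_val, max_val, rtol=0):
    """Hypothetical range check built on numpy.testing.assert_allclose."""
    values = np.asarray(values, dtype=float)
    # Clipping leaves in-range values untouched and moves out-of-range
    # values onto the nearest bound.
    clipped = np.clip(values, min_val, max_val)
    # A value fails if clipping moved it by more than rtol * |bound|.
    # Note: a bound of 0 allows no slack at all, whatever rtol is.
    np.testing.assert_allclose(values, clipped, rtol=rtol)


assert_within_sketch([0.0, 0.5, 1.0], 0.0, 1.0)   # passes
assert_within_sketch([1.05], 0.0, 1.0, rtol=0.1)  # passes: 0.05 <= 0.1 * 1.0
```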

## Development

*Unless noted otherwise, all commands are expected to be executed from the root directory of this repository.*

### Building the package for local development

To make the package available locally while making sure changes to the files are reflected immediately, run

```sh
pip install -e .
```

### Test suite

Run all tests using

```sh
python -m unittest discover tests
```