# Torch Testing
A collection of assertion methods to compare PyTorch Tensors in tests.
Currently, every assertion method converts the tensors to NumPy arrays and passes them to the appropriate `numpy.testing` function. That way, a failing test reports detailed information about why it failed.
Last tested with **Python 3.6.4 :: Anaconda, Inc.** and **PyTorch 0.4**.
## Installation
You can install this package using `pip`:
```sh
pip install torch_testing
```
## Usage example
You can assert the equality of two `torch.Tensor`s like this:
```py
import unittest

import torch
import torch_testing as tt


class TestSomeClass(unittest.TestCase):
    def test_some_method(self):
        a = torch.tensor([1, 2])
        b = torch.tensor([1, 2])
        # Passes because both tensors hold the same values.
        tt.assert_equal(a, b)


if __name__ == '__main__':
    unittest.main()
```
## Assertion methods
### `assert_equal(actual, expected, **kwargs)`
Currently, this method converts both tensors to NumPy arrays using `tensor.numpy()` and passes them to [numpy.testing.assert_equal](https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_equal.html#numpy.testing.assert_equal).
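To make the conversion step concrete, here is a minimal sketch of what such a wrapper could look like. The function `assert_equal_sketch` is a hypothetical stand-in written for illustration, not the package's actual source; it assumes CPU tensors, for which `tensor.numpy()` is available.

```python
import numpy as np
import torch


def assert_equal_sketch(actual, expected, **kwargs):
    """Hypothetical stand-in for the wrapper described above."""
    # Convert both tensors to NumPy arrays, then delegate the comparison.
    np.testing.assert_equal(actual.numpy(), expected.numpy(), **kwargs)


a = torch.tensor([1, 2])
b = torch.tensor([1, 3])

# Equal tensors: the assertion passes silently.
assert_equal_sketch(a, a.clone())

# Unequal tensors: NumPy raises an AssertionError describing the mismatch.
try:
    assert_equal_sketch(a, b)
except AssertionError as err:
    failure_message = str(err)
    print(failure_message)
```

Because the comparison is delegated, the failure text comes from NumPy rather than from the wrapper itself.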
### `assert_allclose(actual, expected, rtol=1e-07, atol=0, equal_nan=True, **kwargs)`
Currently, this method converts both tensors to NumPy arrays using `tensor.numpy()` and passes them to [numpy.testing.assert_allclose](https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_allclose.html#numpy.testing.assert_allclose).
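The tolerances follow NumPy's rule: an element passes when `|actual - expected| <= atol + rtol * |expected|`. The sketch below illustrates that rule with a hypothetical wrapper, `assert_allclose_sketch`, which mirrors the signature above but is not the package's own code.

```python
import numpy as np
import torch


def assert_allclose_sketch(actual, expected, rtol=1e-07, atol=0,
                           equal_nan=True, **kwargs):
    """Hypothetical stand-in mirroring the signature documented above."""
    np.testing.assert_allclose(actual.numpy(), expected.numpy(),
                               rtol=rtol, atol=atol,
                               equal_nan=equal_nan, **kwargs)


expected = torch.tensor([1.0, 100.0])
actual = torch.tensor([1.0005, 100.05])

# With rtol=1e-3 the allowed error scales with |expected|:
# about 0.001 for the first element and 0.1 for the second.
assert_allclose_sketch(actual, expected, rtol=1e-3)

# With the default rtol=1e-07 the same tensors are too far apart.
try:
    assert_allclose_sketch(actual, expected)
    strict_failed = False
except AssertionError:
    strict_failed = True

# With equal_nan=True, NaNs in matching positions compare as equal.
nan_tensor = torch.tensor([float("nan"), 1.0])
assert_allclose_sketch(nan_tensor, nan_tensor.clone())
```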
### `assert_within(tensor, min_val, max_val, rtol=0)`
Ensures that all values of the given `tensor` are greater than or equal to `min_val` and less than or equal to `max_val`. An optional relative tolerance `rtol` can be specified; it behaves as in [numpy.testing.assert_allclose](https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_allclose.html#numpy.testing.assert_allclose).
*NOTE: Uses `assert_allclose` under the hood, so the failure message may currently be a little confusing.*
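One way a range check can be built on top of `assert_allclose` is to compare the tensor with a copy clamped to the allowed range. This is an assumption made for illustration: the package's real implementation is not shown here, and `assert_within_sketch` is a hypothetical name.

```python
import numpy as np
import torch


def assert_within_sketch(tensor, min_val, max_val, rtol=0):
    """Hypothetical range check; not the package's actual implementation."""
    # Values inside [min_val, max_val] are unchanged by clamp, while
    # out-of-range values are moved to the nearest bound.
    clamped = tensor.clamp(min=min_val, max=max_val)
    # Any difference between the two therefore marks an out-of-range value.
    np.testing.assert_allclose(tensor.numpy(), clamped.numpy(),
                               rtol=rtol, atol=0)


# All values lie within [0, 1], so this passes.
assert_within_sketch(torch.tensor([0.0, 0.5, 1.0]), 0.0, 1.0)

# 1.05 exceeds the upper bound, so the strict check fails.
overshoot = torch.tensor([0.5, 1.05])
try:
    assert_within_sketch(overshoot, 0.0, 1.0)
    strict_failed = False
except AssertionError:
    strict_failed = True

# A relative tolerance of 10% of the bound accepts the small overshoot.
assert_within_sketch(overshoot, 0.0, 1.0, rtol=0.1)
```

In a design like this, the failure report describes a mismatch against the clamped array rather than naming the violated bound, which would explain why the message can be hard to read.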
## Development
*Unless noted otherwise, all commands are expected to be executed from the root directory of this repository.*
### Building the package for local development
To make the package available locally while making sure changes to the files are reflected immediately, run
```sh
pip install -e .
```
### Test suite
Run all tests using
```sh
python -m unittest discover tests
```