# Torch Testing
A collection of assertion methods to compare PyTorch Tensors in tests.
Currently, every assertion method converts the tensors to numpy arrays and passes them to the corresponding `numpy.testing` function. As a result, a failing test reports detailed information about why it failed.
Last tested with **Python 3.6.4 :: Anaconda, Inc.** and **PyTorch 0.4**.
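To illustrate the idea, the snippet below performs the conversion by hand. Both tensors are turned into numpy arrays with `tensor.numpy()` and handed to `numpy.testing.assert_equal`, whose `AssertionError` message describes the mismatching elements. It uses only `torch` and `numpy`, not this package, and the helper name `explain_mismatch` is hypothetical.

```python
import numpy as np
import torch


def explain_mismatch(actual: torch.Tensor, expected: torch.Tensor) -> str:
    """Return numpy.testing's failure message, or "" if the tensors are equal.

    Hypothetical helper for illustration only; it is not part of torch_testing.
    """
    try:
        # Convert both tensors to numpy arrays and let numpy do the comparison.
        np.testing.assert_equal(actual.numpy(), expected.numpy())
    except AssertionError as err:
        # The message lists which elements differ and by how much.
        return str(err)
    return ""


message = explain_mismatch(torch.tensor([1, 2, 3]), torch.tensor([1, 5, 3]))
print(message)
```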
## Installation
You can install this package using `pip`:
```sh
pip install torch_testing
```
## Usage example
You can assert that two tensors created with `torch.tensor` are equal like this:
```py
import unittest

import torch
import torch_testing as tt


class TestSomeClass(unittest.TestCase):
    def test_some_method(self):
        a = torch.tensor([1, 2])
        b = torch.tensor([1, 2])
        tt.assert_equal(a, b)


if __name__ == '__main__':
    unittest.main()
```
## Assertion methods
### `assert_equal(actual, expected, **kwargs)`
Currently, this method converts both tensors to `numpy` arrays with `tensor.numpy()` and passes them to [numpy.testing.assert_equal](https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_equal.html#numpy.testing.assert_equal).
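A minimal sketch of such a wrapper, assuming that `**kwargs` (for example `err_msg` or `verbose`) are simply forwarded to `numpy.testing.assert_equal`. It illustrates the described approach and is not the package's actual source.

```python
import numpy as np
import torch


def assert_equal(actual: torch.Tensor, expected: torch.Tensor, **kwargs) -> None:
    """Sketch: compare two tensors element-wise via numpy.testing.assert_equal."""
    # Assumption: extra keyword arguments are passed straight through to numpy.
    np.testing.assert_equal(actual.numpy(), expected.numpy(), **kwargs)


assert_equal(torch.tensor([1, 2]), torch.tensor([1, 2]))  # passes silently
```

Note that `tensor.numpy()` raises an error for tensors stored on a GPU or tensors that require gradients, so such tensors may need `.detach().cpu()` before being compared this way.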
### `assert_allclose(actual, expected, rtol=1e-07, atol=0, equal_nan=True, **kwargs)`
Currently, this method converts both tensors to `numpy` arrays with `tensor.numpy()` and passes them to [numpy.testing.assert_allclose](https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_allclose.html#numpy.testing.assert_allclose).
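The sketch below shows what this forwarding could look like, assuming the keyword arguments map one-to-one onto `numpy.testing.assert_allclose`, which checks `|actual - expected| <= atol + rtol * |expected|` element-wise. It is an illustration, not the package's source.

```python
import numpy as np
import torch


def assert_allclose(actual, expected, rtol=1e-07, atol=0, equal_nan=True, **kwargs):
    """Sketch: forward to numpy.testing.assert_allclose with the same defaults."""
    np.testing.assert_allclose(
        actual.numpy(),
        expected.numpy(),
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,  # NaNs in matching positions count as equal
        **kwargs,
    )


a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.0001])
# The difference (about 1e-4) is below 1e-3 * |2.0001| (about 2e-3), so this passes.
assert_allclose(a, b, rtol=1e-3)
```

With the default `rtol=1e-07` the same pair fails, which is usually too strict for `float32` results computed in different orders.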
### `assert_within(tensor, min_val, max_val, rtol=0)`
Ensures that every value of the given `tensor` is greater than or equal to `min_val` and less than or equal to `max_val`. An optional relative tolerance `rtol` can be specified; it behaves as in [numpy.testing.assert_allclose](https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_allclose.html#numpy.testing.assert_allclose).
*NOTE: Uses `assert_allclose` under the hood, so the failure message may currently be a little confusing.*
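One plausible way to build this check on top of `assert_allclose` (an assumption, not the package's confirmed source) is to clamp the tensor into `[min_val, max_val]` with `torch.clamp` and compare the original against the clamped copy. In-range values are unchanged by the clamp, while an out-of-range value is compared with the nearest bound, so `rtol` permits an overshoot of up to `rtol * |bound|`. This would also explain why the failure message talks about tolerances rather than bounds.

```python
import numpy as np
import torch


def assert_within(tensor, min_val, max_val, rtol=0):
    """Sketch: check min_val <= tensor <= max_val, up to relative tolerance rtol."""
    # Values inside the bounds are left untouched, so they compare equal;
    # values outside are replaced by the nearest bound.
    clamped = torch.clamp(tensor, min_val, max_val)
    # Passes if |tensor - clamped| <= rtol * |clamped| for every element.
    np.testing.assert_allclose(tensor.numpy(), clamped.numpy(), rtol=rtol, atol=0)


assert_within(torch.tensor([0.0, 0.5, 1.0]), 0.0, 1.0)
assert_within(torch.tensor([1.05]), 0.0, 1.0, rtol=0.1)  # overshoot 0.05 <= 0.1 * 1.0
```

Because the slack is relative to the bound, a bound of `0` gets no tolerance at all in this sketch.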
## Development
*Unless noted otherwise, all commands are expected to be executed from the root directory of this repository.*
### Building the package for local development
To make the package available locally while making sure changes to the files are reflected immediately, run
```sh
pip install -e .
```
### Test suite
Run all tests using
```sh
python -m unittest discover tests
```
## File details: `torch_testing-0.0.2.tar.gz`
- Download URL: torch_testing-0.0.2.tar.gz
- Size: 2.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No

| Algorithm | Hash digest |
|---|---|
| SHA256 | 65fc4232ad4cba84da32e8054eb0fef8ba393633ba61d2be056c6c1d85869b35 |
| MD5 | b88de24af969db4b778a21fa0097948d |
| BLAKE2b-256 | e5858b5d7b8220fc483b329c81ea842d70535617098617333d112a156b0a0ca0 |
## File details: `torch_testing-0.0.2-py3-none-any.whl`
- Download URL: torch_testing-0.0.2-py3-none-any.whl
- Size: 4.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No

| Algorithm | Hash digest |
|---|---|
| SHA256 | fba9a8db3b67c61630e6c8f539be6b63af81bd4d8561f8cda6ca4ef2270a5750 |
| MD5 | 89dc303c78d3ae13095991fd21736acb |
| BLAKE2b-256 | cffeb3fa6a9400ef0b534bd6d0058247ec1dc4fbc70a6dc0b81bfd5bcb9c16b3 |