Measures and metrics for image-to-image tasks. PyTorch.


PyTorch Image Quality


A collection of measures and metrics for automatic image quality assessment in various image-to-image tasks such as denoising, super-resolution, and image generation. This easy-to-use yet flexible and extensive library is developed with a focus on reliability and reproducibility of results. Use your favourite measures as losses for training neural networks with ready-to-use PyTorch modules.

Getting started

import torch
from piq import ssim

prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
ssim_index = ssim(prediction, target, data_range=1.)

Examples

Peak Signal-to-Noise Ratio (PSNR)

To compute PSNR as a measure, use the lower-case function from the library. By default it computes the average PSNR if the batch contains more than one image. You can specify other reduction methods via the reduction flag.

import torch
from piq import psnr

prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256) 
psnr_mean = psnr(prediction, target, data_range=1., reduction='mean')
psnr_per_image = psnr(prediction, target, data_range=1., reduction='none')

Note: colour images are first converted to YCbCr colour space and only the luminance component is considered.
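For intuition, PSNR on the luminance channel can be sketched in a few lines of plain PyTorch (the BT.601 luma weights and helper names below are illustrative assumptions; piq handles colour conversion and reductions internally):

```python
import torch

def rgb_to_luminance(x: torch.Tensor) -> torch.Tensor:
    # ITU-R BT.601 luma weights (assumed); x is (N, 3, H, W) in [0, 1]
    w = torch.tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
    return (x * w).sum(dim=1, keepdim=True)

def psnr_sketch(x: torch.Tensor, y: torch.Tensor, data_range: float = 1.0) -> torch.Tensor:
    # PSNR = 10 * log10(data_range^2 / MSE), computed on luminance only
    mse = torch.mean((rgb_to_luminance(x) - rgb_to_luminance(y)) ** 2)
    return 10 * torch.log10(data_range ** 2 / mse)

x = torch.zeros(1, 3, 8, 8)
y = torch.full((1, 3, 8, 8), 0.5)
print(psnr_sketch(x, y))  # MSE = 0.25 -> 10 * log10(4) ≈ 6.02 dB
```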

Structural Similarity (SSIM)

To compute the SSIM index as a measure, use the lower-case function from the library:

import torch
from piq import ssim
from typing import Union, Tuple

prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256) 
ssim_index: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = ssim(prediction, target, data_range=1.)

In order to use SSIM as a loss function, use the corresponding PyTorch module:

import torch
from piq import SSIMLoss

loss = SSIMLoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
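A loss module like this drops into an ordinary training step. A minimal sketch with a toy model (the model, optimizer, and random data here are illustrative; in real use the criterion would be piq's SSIMLoss or any other loss module from the library):

```python
import torch
from torch import nn

model = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # toy stand-in for a real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()  # swap in piq.SSIMLoss(data_range=1.) in real use

noisy = torch.rand(4, 3, 32, 32)
clean = torch.rand(4, 3, 32, 32)

optimizer.zero_grad()
loss = criterion(model(noisy), clean)  # scalar loss on the batch
loss.backward()                        # gradients flow into model parameters
optimizer.step()
```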

Multi-Scale Structural Similarity (MS-SSIM)

To compute the MS-SSIM index as a measure, use the lower-case function from the library:

import torch
from piq import multi_scale_ssim

prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256) 
ms_ssim_index: torch.Tensor = multi_scale_ssim(prediction, target, data_range=1.)

In order to use MS-SSIM as a loss function, use the corresponding PyTorch module:

import torch
from piq import MultiScaleSSIMLoss

loss = MultiScaleSSIMLoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()

Total Variation (TV)

To compute TV as a measure, use the lower-case function from the library:

import torch
from piq import total_variation

data = torch.rand(3, 3, 256, 256) 
tv: torch.Tensor = total_variation(data)
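The quantity itself is simple: the sum of absolute differences between neighbouring pixels. A minimal sketch of the anisotropic (L1) variant (piq supports several TV norms, so this matches only one of them; the helper name is illustrative):

```python
import torch

def tv_l1(x: torch.Tensor) -> torch.Tensor:
    # anisotropic total variation: |vertical diffs| + |horizontal diffs|
    dh = (x[..., 1:, :] - x[..., :-1, :]).abs().sum()
    dw = (x[..., :, 1:] - x[..., :, :-1]).abs().sum()
    return dh + dw

flat = torch.ones(1, 1, 4, 4)
print(tv_l1(flat))  # a constant image has zero total variation
```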

In order to use TV as a loss function, use the corresponding PyTorch module:

import torch
from piq import TVLoss

loss = TVLoss()
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
output: torch.Tensor = loss(prediction)
output.backward()

Visual Information Fidelity (VIF)

To compute VIF as a measure, use the lower-case function from the library:

import torch
from piq import vif_p

predicted = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
vif: torch.Tensor = vif_p(predicted, target, data_range=1.)

In order to use VIF as a loss function, use the corresponding PyTorch module:

import torch
from piq import VIFLoss

loss = VIFLoss(sigma_n_sq=2.0, data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()

Note that VIFLoss returns 1 - VIF.

Gradient Magnitude Similarity Deviation (GMSD)

This is a port of the MATLAB version from the authors of the original paper. It can be used both as a measure and as a loss function; in either case it should be minimized. Values of GMSD usually lie in the [0, 0.35] interval.

import torch
from piq import GMSDLoss

loss = GMSDLoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
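The core computation can be sketched as follows: gradient magnitudes of both images (via Prewitt filters, following the paper), a pixel-wise similarity map, and its standard deviation. The constant and the single-channel restriction below are simplifying assumptions; piq additionally downsamples the input:

```python
import torch
import torch.nn.functional as F

def gmsd_sketch(x: torch.Tensor, y: torch.Tensor, c: float = 0.0026) -> torch.Tensor:
    # x, y: single-channel images of shape (N, 1, H, W)
    prewitt = torch.tensor([[1., 0., -1.], [1., 0., -1.], [1., 0., -1.]]) / 3.0
    kernels = torch.stack([prewitt, prewitt.t()]).unsqueeze(1)  # (2, 1, 3, 3)
    gm_x = F.conv2d(x, kernels, padding=1).pow(2).sum(dim=1).sqrt()
    gm_y = F.conv2d(y, kernels, padding=1).pow(2).sum(dim=1).sqrt()
    gms = (2 * gm_x * gm_y + c) / (gm_x ** 2 + gm_y ** 2 + c)  # similarity map in (0, 1]
    return gms.std()  # deviation: 0 means identical gradient structure

x = torch.rand(1, 1, 64, 64)
print(gmsd_sketch(x, x))  # identical inputs -> deviation of 0
```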

Multi-Scale GMSD (MS-GMSD)

It can be used both as a measure and as a loss function; in either case it should be minimized. By default, scale weights are initialized with values from the paper. You can change them by passing a list of 4 values as the scale_weights argument during initialization. Both GMSD and MS-GMSD are computed for greyscale images, but to take contrast changes into account the authors proposed to also add a chromatic component. Use the chromatic flag to get the MS-GMSDc version of the loss.

import torch
from piq import MultiScaleGMSDLoss

loss = MultiScaleGMSDLoss(chromatic=True, data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()

Blind/Referenceless Image Spatial Quality Evaluator (BRISQUE)

To compute the BRISQUE score as a measure, use the lower-case function from the library:

import torch
from piq import brisque

prediction = torch.rand(3, 3, 256, 256)
brisque_index: torch.Tensor = brisque(prediction, data_range=1.)

In order to use BRISQUE as a loss function, use the corresponding PyTorch module:

import torch
from piq import BRISQUELoss

loss = BRISQUELoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
output: torch.Tensor = loss(prediction)
output.backward()

Multi-Scale Intrinsic Distance (MSID)

Use the MSID class to compute the MSID score from image features, pre-extracted by some feature extractor network:

import torch
from piq import MSID

msid_metric = MSID()
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
msid: torch.Tensor = msid_metric(prediction_feats, target_feats)

If image features are not available, extract them using the _compute_feats method of the MSID class. Note that _compute_feats consumes a data loader of a predefined format.

import torch
from torch.utils.data import DataLoader
from piq import MSID

# DataLoader requires a dataset; replace these placeholders with your image data
first_dl, second_dl = DataLoader(first_dataset), DataLoader(second_dataset)
msid_metric = MSID()
first_feats = msid_metric._compute_feats(first_dl)
second_feats = msid_metric._compute_feats(second_dl)
msid: torch.Tensor = msid_metric(first_feats, second_feats)

Fréchet Inception Distance (FID)

Use the FID class to compute the FID score from image features, pre-extracted by some feature extractor network:

import torch
from piq import FID

fid_metric = FID()
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
fid: torch.Tensor = fid_metric(prediction_feats, target_feats)

If image features are not available, extract them using the _compute_feats method of the FID class. Note that _compute_feats consumes a data loader of a predefined format.

import torch
from torch.utils.data import DataLoader
from piq import FID

# DataLoader requires a dataset; replace these placeholders with your image data
first_dl, second_dl = DataLoader(first_dataset), DataLoader(second_dataset)
fid_metric = FID()
first_feats = fid_metric._compute_feats(first_dl)
second_feats = fid_metric._compute_feats(second_dl)
fid: torch.Tensor = fid_metric(first_feats, second_feats)

Kernel Inception Distance (KID)

Use the KID class to compute the KID score from image features, pre-extracted by some feature extractor network:

import torch
from piq import KID

kid_metric = KID()
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
kid: torch.Tensor = kid_metric(prediction_feats, target_feats)

If image features are not available, extract them using the _compute_feats method of the KID class. Note that _compute_feats consumes a data loader of a predefined format.

import torch
from torch.utils.data import DataLoader
from piq import KID

# DataLoader requires a dataset; replace these placeholders with your image data
first_dl, second_dl = DataLoader(first_dataset), DataLoader(second_dataset)
kid_metric = KID()
first_feats = kid_metric._compute_feats(first_dl)
second_feats = kid_metric._compute_feats(second_dl)
kid: torch.Tensor = kid_metric(first_feats, second_feats)

Geometry Score (GS)

Use the GS class to compute the Geometry Score from image features, pre-extracted by some feature extractor network. Computation is heavily CPU-dependent; adjust the num_workers parameter according to your system configuration:

import torch
from piq import GS

gs_metric = GS(sample_size=64, num_iters=100, i_max=100, num_workers=4)
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
gs: torch.Tensor = gs_metric(prediction_feats, target_feats)

The GS metric requires the gudhi library, which is not installed by default. If you use conda, run: conda install -c conda-forge gudhi; otherwise follow the installation guide.

Inception Score (IS)

Use the inception_score function to compute IS from image features, pre-extracted by some feature extractor network. Note that we follow the recommendations of the paper "A Note on the Inception Score", which proposes a small modification to the original algorithm:

import torch
from piq import inception_score

prediction_feats = torch.rand(10000, 1024)
mean, variance = inception_score(prediction_feats, num_splits=10)
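The score itself is exp(E_x[KL(p(y|x) || p(y))]). A minimal sketch on softmax class probabilities (the helper name is illustrative; piq additionally splits the data and reports the mean and variance across splits):

```python
import torch

def inception_score_sketch(probs: torch.Tensor) -> torch.Tensor:
    # probs: (N, num_classes) rows of softmax class probabilities p(y|x)
    p_y = probs.mean(dim=0, keepdim=True)                # marginal p(y)
    kl = (probs * (probs.log() - p_y.log())).sum(dim=1)  # KL(p(y|x) || p(y)) per sample
    return kl.mean().exp()

uniform = torch.full((100, 10), 0.1)
print(inception_score_sketch(uniform))  # uninformative predictions -> score of 1.0
```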

To compute the difference between IS values for two sets of image features, use the IS class.

import torch
from piq import IS


is_metric = IS(distance='l1') 
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
distance: torch.Tensor = is_metric(prediction_feats, target_feats)


Overview

PyTorch Image Quality (formerly PhotoSynthesis.Metrics) helps you concentrate on your experiments without boilerplate code. The library contains a set of measures and metrics that is constantly being extended. For measures/metrics that can be used as loss functions, corresponding PyTorch modules are implemented.

Installation

$ pip install piq

If you want to use the latest features straight from master, clone the repo:

$ git clone https://github.com/photosynthesis-team/piq.git

Roadmap

See the open issues for a list of proposed features and known issues.

Community

Contributing

We appreciate all contributions. If you plan to:

  • contribute back bug-fixes, please do so without any further discussion
  • close one of the open issues, please do so if no one has been assigned to it
  • contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us

Please see the contribution guide for more information.

Contact

Sergey Kastryulin - @snk4tr - snk4tr@gmail.com

Project Link: https://github.com/photosynthesis-team/piq
PhotoSynthesis Team: https://github.com/photosynthesis-team


Acknowledgements

  • Pavel Parunin - @PavelParunin - idea proposal and development
  • Djamil Zakirov - @zakajd - development
  • Denis Prokopenko - @denproc - development
