Measures and metrics for image2image tasks. PyTorch.
PhotoSynthesis.Metrics
PyTorch library with measures and metrics for various image-to-image tasks like denoising, super-resolution, image generation, etc. This easy-to-use yet flexible and extensive library is developed with a focus on reliability and reproducibility of results. Use your favourite measures as losses for training neural networks with ready-to-use PyTorch modules.
Getting started
```python
import torch
from photosynthesis_metrics import ssim

prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
ssim_index = ssim(prediction, target, data_range=1.)
```
Examples
To compute the SSIM index as a measure, use the lower-case function from the library:
```python
import torch
from typing import Union, Tuple
from photosynthesis_metrics import ssim

prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
ssim_index: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = ssim(prediction, target, data_range=1.)
```
In order to use SSIM as a loss function, use the corresponding PyTorch module:
```python
import torch
from photosynthesis_metrics import SSIMLoss

loss = SSIMLoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
```
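For intuition about what `ssim` measures, here is a minimal NumPy sketch of the SSIM formula from Wang et al. (2004): SSIM(x, y) = ((2·μx·μy + C1)(2·σxy + C2)) / ((μx² + μy² + C1)(σx² + σy² + C2)), with C1 = (0.01·L)², C2 = (0.03·L)² and L = data_range. It deliberately uses a single global window instead of the sliding, usually Gaussian-weighted, window of standard SSIM, so its values will not match the library's `ssim`. `global_ssim` is a hypothetical helper, not part of photosynthesis_metrics.

```python
import numpy as np


def global_ssim(x: np.ndarray, y: np.ndarray, data_range: float = 1.0,
                k1: float = 0.01, k2: float = 0.03) -> float:
    """SSIM from whole-image statistics (a single global window).

    Hypothetical helper: standard SSIM evaluates the same expression in
    local windows and averages the resulting map.
    """
    x = x.astype(np.float64)
    y = y.astype(np.float64)
    c1 = (k1 * data_range) ** 2  # stabilizes the luminance term
    c2 = (k2 * data_range) ** 2  # stabilizes the contrast-structure term
    mu_x, mu_y = x.mean(), y.mean()
    cov_xy = ((x - mu_x) * (y - mu_y)).mean()
    luminance = (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1)
    contrast_structure = (2 * cov_xy + c2) / (x.var() + y.var() + c2)
    return float(luminance * contrast_structure)


rng = np.random.default_rng(0)
img = rng.random((64, 64))
noisy = np.clip(img + 0.1 * rng.standard_normal(img.shape), 0.0, 1.0)
print(global_ssim(img, img))    # 1.0 up to floating-point rounding
print(global_ssim(img, noisy))  # below 1.0
```

Both terms are at most 1 and equal 1 only when means, variances and covariance agree, which is why identical images reach the maximum.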
To compute the MS-SSIM index as a measure, use the lower-case function from the library:
```python
import torch
from photosynthesis_metrics import multi_scale_ssim

prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
ms_ssim_index: torch.Tensor = multi_scale_ssim(prediction, target, data_range=1.)
```
In order to use MS-SSIM as a loss function, use the corresponding PyTorch module:
```python
import torch
from photosynthesis_metrics import MultiScaleSSIMLoss

loss = MultiScaleSSIMLoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
```
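MS-SSIM (Wang, Simoncelli and Bovik, 2003) evaluates the contrast-structure term of SSIM at several scales, halving the resolution between scales, and multiplies the per-scale values raised to fixed exponents; the luminance term is used only at the coarsest scale. The NumPy sketch below follows that recipe with the five exponents from the paper, but it uses global image statistics instead of local windows and clips negative terms to zero, so it illustrates the idea rather than reproducing `multi_scale_ssim`. `global_ms_ssim` and its helpers are hypothetical names.

```python
import numpy as np

# Per-scale exponents from Wang et al. (2003), finest scale first.
MS_SSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)


def _luminance_and_cs(x, y, data_range, k1=0.01, k2=0.03):
    """Global luminance and contrast-structure terms of SSIM."""
    c1, c2 = (k1 * data_range) ** 2, (k2 * data_range) ** 2
    mu_x, mu_y = x.mean(), y.mean()
    cov_xy = ((x - mu_x) * (y - mu_y)).mean()
    luminance = (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1)
    cs = (2 * cov_xy + c2) / (x.var() + y.var() + c2)
    return luminance, cs


def _downsample(x):
    """2x2 average pooling; a trailing odd row/column is dropped."""
    h, w = x.shape[0] // 2 * 2, x.shape[1] // 2 * 2
    x = x[:h, :w]
    return 0.25 * (x[0::2, 0::2] + x[1::2, 0::2] + x[0::2, 1::2] + x[1::2, 1::2])


def global_ms_ssim(x: np.ndarray, y: np.ndarray, data_range: float = 1.0,
                   weights=MS_SSIM_WEIGHTS) -> float:
    x, y = x.astype(np.float64), y.astype(np.float64)
    result = 1.0
    for scale, weight in enumerate(weights):
        luminance, cs = _luminance_and_cs(x, y, data_range)
        # Clip negatives so fractional powers stay real (a choice of this sketch).
        result *= max(cs, 0.0) ** weight
        if scale == len(weights) - 1:
            result *= max(luminance, 0.0) ** weight  # luminance only at the coarsest scale
        else:
            x, y = _downsample(x), _downsample(y)
    return float(result)


rng = np.random.default_rng(0)
img = rng.random((64, 64))
noisy = np.clip(img + 0.1 * rng.standard_normal(img.shape), 0.0, 1.0)
print(global_ms_ssim(img, img))    # 1.0 up to floating-point rounding
print(global_ms_ssim(img, noisy))  # between 0 and 1
```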
To compute TV as a measure, use the lower-case function from the library:
```python
import torch
from photosynthesis_metrics import total_variation

data = torch.rand(3, 3, 256, 256)
tv: torch.Tensor = total_variation(data)
```
In order to use TV as a loss function, use the corresponding PyTorch module:
```python
import torch
from photosynthesis_metrics import TVLoss

loss = TVLoss()
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
output: torch.Tensor = loss(prediction)
output.backward()
```
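As a rough illustration of what a total variation measure computes, the NumPy sketch below sums absolute (anisotropic, `"l1"`) or Euclidean (isotropic, `"l2"`) differences between neighbouring pixels for each image in an (N, C, H, W) batch. Which variant and reduction `total_variation` and `TVLoss` use by default is not stated here, so treat `total_variation_np` as a hypothetical helper rather than the library's implementation.

```python
import numpy as np


def total_variation_np(x: np.ndarray, norm: str = "l1") -> np.ndarray:
    """Total variation per image for a batch shaped (N, C, H, W).

    "l1": anisotropic TV, sum of |vertical| and |horizontal| differences.
    "l2": isotropic TV, sum of sqrt(dh^2 + dw^2) on the common (H-1, W-1) grid.
    """
    dh = x[:, :, 1:, :] - x[:, :, :-1, :]  # vertical differences, (N, C, H-1, W)
    dw = x[:, :, :, 1:] - x[:, :, :, :-1]  # horizontal differences, (N, C, H, W-1)
    if norm == "l1":
        return np.abs(dh).sum(axis=(1, 2, 3)) + np.abs(dw).sum(axis=(1, 2, 3))
    if norm == "l2":
        dh, dw = dh[:, :, :, :-1], dw[:, :, :-1, :]  # crop both to (N, C, H-1, W-1)
        return np.sqrt(dh ** 2 + dw ** 2).sum(axis=(1, 2, 3))
    raise ValueError(f"unknown norm: {norm!r}")


flat = np.zeros((1, 1, 2, 2))
step = np.array([[[[0.0, 1.0], [0.0, 1.0]]]])  # one vertical edge
print(total_variation_np(flat))              # [0.]
print(total_variation_np(step))              # [2.]
print(total_variation_np(step, norm="l2"))   # [1.]
```

Noise adds many small pixel-to-pixel jumps, so TV grows with noise; this is why it is commonly used as a smoothness regularizer.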
To compute VIF as a measure, use the lower-case function from the library:
```python
import torch
from photosynthesis_metrics import vif_p

predicted = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
vif: torch.Tensor = vif_p(predicted, target, data_range=1.)
```
In order to use VIF as a loss function, use the corresponding PyTorch class:
```python
import torch
from photosynthesis_metrics import VIFLoss

loss = VIFLoss(sigma_n_sq=2.0, data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
```
Note that VIFLoss returns 1 - VIF rather than the VIF value itself.
The GMSD implementation is a port of the MATLAB version from the authors of the original paper. It can be used both as a measure and as a loss function; in either case, it should be minimized. GMSD values usually lie in the [0, 0.35] interval.
```python
import torch
from photosynthesis_metrics import GMSDLoss

loss = GMSDLoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
```
MS-GMSD can also be used both as a measure and as a loss function; in either case, it should be minimized.
By default, scale weights are initialized with values from the paper. You can change them by passing a list of 4 values to the scale_weights argument during initialization. Both GMSD and MS-GMSD are computed for greyscale images, but to take contrast changes into account, the authors proposed adding a chromatic component as well. Use the chromatic flag to switch to the MS-GMSDc version of the loss:
```python
import torch
from photosynthesis_metrics import MultiScaleGMSDLoss

loss = MultiScaleGMSDLoss(chromatic=True, data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
```
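GMSD (Xue et al., 2014) compares gradient magnitudes: with Prewitt gradient magnitudes m_p and m_t of the two images, it forms the similarity map GMS = (2·m_p·m_t + c) / (m_p² + m_t² + c) and reports the standard deviation of that map, so identical images score 0. The NumPy sketch below follows that outline for a single grayscale pair, uses the paper's constant c = 170 for 8-bit images rescaled to data_range, and approximates the paper's factor-2 downsampling with 2x2 average pooling. The helper names are hypothetical, and the output is not expected to match `GMSDLoss` exactly (padding and color-conversion details may differ).

```python
import numpy as np


def _filter2d_valid(img: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """2-D cross-correlation over the 'valid' region (no padding)."""
    kh, kw = kernel.shape
    h, w = img.shape[0] - kh + 1, img.shape[1] - kw + 1
    out = np.zeros((h, w))
    for i in range(kh):
        for j in range(kw):
            out += kernel[i, j] * img[i:i + h, j:j + w]
    return out


def _downsample(x: np.ndarray) -> np.ndarray:
    """2x2 average pooling; a trailing odd row/column is dropped."""
    h, w = x.shape[0] // 2 * 2, x.shape[1] // 2 * 2
    x = x[:h, :w]
    return 0.25 * (x[0::2, 0::2] + x[1::2, 0::2] + x[0::2, 1::2] + x[1::2, 1::2])


def gmsd_np(pred: np.ndarray, target: np.ndarray, data_range: float = 1.0) -> float:
    """GMSD for a pair of 2-D grayscale images (lower is better)."""
    c = 170.0 * (data_range / 255.0) ** 2  # paper's constant for 8-bit images, rescaled
    prewitt_x = np.array([[1.0, 0.0, -1.0]] * 3) / 3.0
    prewitt_y = prewitt_x.T

    def gradient_magnitude(x: np.ndarray) -> np.ndarray:
        x = _downsample(x.astype(np.float64))
        gx = _filter2d_valid(x, prewitt_x)
        gy = _filter2d_valid(x, prewitt_y)
        return np.sqrt(gx ** 2 + gy ** 2)

    m_p, m_t = gradient_magnitude(pred), gradient_magnitude(target)
    gms = (2 * m_p * m_t + c) / (m_p ** 2 + m_t ** 2 + c)  # similarity map, values in (0, 1]
    return float(gms.std())  # GMSD: standard deviation of the similarity map


rng = np.random.default_rng(0)
img = rng.random((64, 64))
noisy = np.clip(img + 0.1 * rng.standard_normal(img.shape), 0.0, 1.0)
print(gmsd_np(img, img))    # 0.0 for identical images
print(gmsd_np(img, noisy))  # positive; larger means more distortion
```

Using the deviation rather than the mean of the map rewards distortions that are spread evenly across the image less harshly than localized ones, which is the design argument made in the GMSD paper.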
To compute the BRISQUE score as a measure, use the lower-case function from the library:
```python
import torch
from photosynthesis_metrics import brisque

prediction = torch.rand(3, 3, 256, 256)
brisque_index: torch.Tensor = brisque(prediction, data_range=1.)
```
In order to use BRISQUE as a loss function, use the corresponding PyTorch module:
```python
import torch
from photosynthesis_metrics import BRISQUELoss

loss = BRISQUELoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
output: torch.Tensor = loss(prediction)
output.backward()
```
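BRISQUE (Mittal et al., 2012) is a no-reference measure: it first converts the image into mean-subtracted contrast-normalized (MSCN) coefficients (I - μ) / (σ + C), with μ and σ computed in a 7x7 Gaussian window (σ = 7/6) and C = 1 for 8-bit images, then fits generalized Gaussian models to these coefficients and their neighbour products and feeds the fitted parameters to a support vector regressor. The NumPy sketch below covers only the MSCN step, with C rescaled to images in [0, 1]. `mscn_coefficients` and the reflect padding are assumptions of this sketch, not the library's code.

```python
import numpy as np


def gaussian_kernel(size: int = 7, sigma: float = 7 / 6) -> np.ndarray:
    """Normalized 2-D Gaussian window."""
    ax = np.arange(size) - (size - 1) / 2
    g = np.exp(-(ax ** 2) / (2 * sigma ** 2))
    k = np.outer(g, g)
    return k / k.sum()


def _filter2d_same(img: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """2-D cross-correlation with reflect padding; output has the input shape."""
    kh, kw = kernel.shape
    padded = np.pad(img, ((kh // 2, kh // 2), (kw // 2, kw // 2)), mode="reflect")
    out = np.zeros(img.shape, dtype=np.float64)
    for i in range(kh):
        for j in range(kw):
            out += kernel[i, j] * padded[i:i + img.shape[0], j:j + img.shape[1]]
    return out


def mscn_coefficients(img: np.ndarray, c: float = 1.0 / 255) -> np.ndarray:
    """MSCN coefficients of a 2-D grayscale image with values in [0, 1]."""
    img = img.astype(np.float64)
    k = gaussian_kernel()
    mu = _filter2d_same(img, k)                    # local weighted mean
    var = _filter2d_same(img ** 2, k) - mu ** 2    # local weighted variance
    sigma = np.sqrt(np.abs(var))                   # abs guards tiny negative rounding errors
    return (img - mu) / (sigma + c)


rng = np.random.default_rng(0)
img = rng.random((64, 64))
mscn = mscn_coefficients(img)
print(mscn.shape)  # (64, 64)
```

For natural undistorted images the MSCN histogram is close to a unit Gaussian; distortions change its shape, and that change is what the later BRISQUE features capture.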
Use the MSID class to compute the MSID score from image features pre-extracted by some feature extractor network:
```python
import torch
from photosynthesis_metrics import MSID

msid_metric = MSID()
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
msid: torch.Tensor = msid_metric(prediction_feats, target_feats)
```
If image features are not available, extract them using the _compute_feats method of the MSID class. Please note that _compute_feats consumes a data loader of a predefined format.
```python
import torch
from torch.utils.data import DataLoader
from photosynthesis_metrics import MSID

# Placeholders: build DataLoaders from your own datasets that yield
# data in the format expected by _compute_feats.
first_dl: DataLoader = ...
second_dl: DataLoader = ...

msid_metric = MSID()
first_feats = msid_metric._compute_feats(first_dl)
second_feats = msid_metric._compute_feats(second_dl)
msid: torch.Tensor = msid_metric(first_feats, second_feats)
```
Use the FID class to compute the FID score from image features pre-extracted by some feature extractor network:
```python
import torch
from photosynthesis_metrics import FID

fid_metric = FID()
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
fid: torch.Tensor = fid_metric(prediction_feats, target_feats)
```
If image features are not available, extract them using the _compute_feats method of the FID class. Please note that _compute_feats consumes a data loader of a predefined format.
```python
import torch
from torch.utils.data import DataLoader
from photosynthesis_metrics import FID

# Placeholders: build DataLoaders from your own datasets that yield
# data in the format expected by _compute_feats.
first_dl: DataLoader = ...
second_dl: DataLoader = ...

fid_metric = FID()
first_feats = fid_metric._compute_feats(first_dl)
second_feats = fid_metric._compute_feats(second_dl)
fid: torch.Tensor = fid_metric(first_feats, second_feats)
```
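FID fits a Gaussian to each (N, D) feature set and computes the Fréchet distance between them: ||μ1 - μ2||² + Tr(Σ1 + Σ2 - 2(Σ1Σ2)^{1/2}). The NumPy sketch below, with the hypothetical helper `frechet_distance`, obtains the trace term from the eigenvalues of Σ1Σ2 instead of forming the matrix square root explicitly; it is meant to show the formula, not to reproduce the FID class's numerics.

```python
import numpy as np


def frechet_distance(feats_1: np.ndarray, feats_2: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two (N, D) feature sets."""
    mu_1, mu_2 = feats_1.mean(axis=0), feats_2.mean(axis=0)
    sigma_1 = np.cov(feats_1, rowvar=False)
    sigma_2 = np.cov(feats_2, rowvar=False)
    # Tr(sqrt(S1 @ S2)) is the sum of square roots of the eigenvalues of S1 @ S2,
    # which are real and non-negative in exact arithmetic; clip rounding noise.
    eigvals = np.linalg.eigvals(sigma_1 @ sigma_2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu_1 - mu_2
    return float(diff @ diff + np.trace(sigma_1) + np.trace(sigma_2) - 2.0 * tr_covmean)


rng = np.random.default_rng(0)
a = rng.standard_normal((2000, 8))
b = rng.standard_normal((2000, 8)) + 0.5
print(frechet_distance(a, a))  # ~0 for identical feature sets
print(frechet_distance(a, b))  # grows with the mean and covariance mismatch
```

With equal covariances the trace terms cancel and only the squared mean difference remains, which makes a constant shift of every feature a convenient sanity check.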
Use the KID class to compute the KID score from image features pre-extracted by some feature extractor network:
```python
import torch
from photosynthesis_metrics import KID

kid_metric = KID()
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
kid: torch.Tensor = kid_metric(prediction_feats, target_feats)
```
If image features are not available, extract them using the _compute_feats method of the KID class. Please note that _compute_feats consumes a data loader of a predefined format.
```python
import torch
from torch.utils.data import DataLoader
from photosynthesis_metrics import KID

# Placeholders: build DataLoaders from your own datasets that yield
# data in the format expected by _compute_feats.
first_dl: DataLoader = ...
second_dl: DataLoader = ...

kid_metric = KID()
first_feats = kid_metric._compute_feats(first_dl)
second_feats = kid_metric._compute_feats(second_dl)
kid: torch.Tensor = kid_metric(first_feats, second_feats)
```
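KID (Bińkowski et al., 2018) is an unbiased estimate of the squared maximum mean discrepancy (MMD) between two feature sets under the cubic polynomial kernel k(x, y) = (x·y / d + 1)³, where d is the feature dimension. The NumPy sketch below computes that estimate once on the full sets; the paper averages it over random subsets, and whatever subset sizes and defaults the library's KID class uses are not shown here. `kid_mmd2` and `polynomial_kernel` are hypothetical helper names.

```python
import numpy as np


def polynomial_kernel(x: np.ndarray, y: np.ndarray, degree: int = 3, coef0: float = 1.0) -> np.ndarray:
    """k(x, y) = (x . y / d + coef0) ** degree for every pair of rows of x and y."""
    d = x.shape[1]
    return (x @ y.T / d + coef0) ** degree


def kid_mmd2(feats_1: np.ndarray, feats_2: np.ndarray) -> float:
    """Unbiased MMD^2 estimate between two (N, D) feature sets with the cubic kernel."""
    m, n = feats_1.shape[0], feats_2.shape[0]
    k_11 = polynomial_kernel(feats_1, feats_1)
    k_22 = polynomial_kernel(feats_2, feats_2)
    k_12 = polynomial_kernel(feats_1, feats_2)
    # Exclude the i == j terms so the within-set averages are unbiased.
    term_11 = (k_11.sum() - np.trace(k_11)) / (m * (m - 1))
    term_22 = (k_22.sum() - np.trace(k_22)) / (n * (n - 1))
    term_12 = k_12.sum() / (m * n)
    return float(term_11 + term_22 - 2.0 * term_12)


rng = np.random.default_rng(0)
real = rng.standard_normal((500, 8))
same = rng.standard_normal((500, 8))            # same distribution as `real`
shifted = rng.standard_normal((500, 8)) + 1.0   # mean shifted by 1 in every dimension
print(kid_mmd2(real, same))     # close to 0 (the unbiased estimate can be slightly negative)
print(kid_mmd2(real, shifted))  # clearly positive
```

Unlike FID, this estimator has no bias that depends on the number of samples, which is the main argument the KID paper makes for it.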
Use the GS class to compute the Geometry Score from image features pre-extracted by some feature extractor network. The computation is heavily CPU-dependent, so adjust the num_workers parameter according to your system configuration:
```python
import torch
from photosynthesis_metrics import GS

gs_metric = GS(sample_size=64, num_iters=100, i_max=100, num_workers=4)
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
gs: torch.Tensor = gs_metric(prediction_feats, target_feats)
```
The GS metric requires the gudhi library, which is not installed by default. If you use conda, install it with conda install -c conda-forge gudhi; otherwise, follow the installation guide.
Use the inception_score function to compute IS from image features pre-extracted by some feature extractor network. Note that we follow the recommendations from the paper A Note on the Inception Score, which proposed a small modification to the original algorithm:
```python
import torch
from photosynthesis_metrics import inception_score

prediction_feats = torch.rand(10000, 1024)
mean, variance = inception_score(prediction_feats, num_splits=10)
```
To compute the difference between the IS values of 2 sets of image features, use the IS class:
```python
import torch
from photosynthesis_metrics import IS

is_metric = IS(distance='l1')
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
distance: torch.Tensor = is_metric(prediction_feats, target_feats)
```
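The Inception Score is exp(E_x[KL(p(y|x) || p(y))]), usually computed on several splits of the samples and reported as a mean and a spread. The NumPy sketch below shows the plain formula on an (N, K) matrix of class probabilities, such as softmax outputs of a classifier. It does not implement the modification from A Note on the Inception Score mentioned above, and it reports the standard deviation across splits, whereas the library's second return value is named variance. `inception_score_np` is a hypothetical helper.

```python
import numpy as np


def inception_score_np(probs: np.ndarray, num_splits: int = 10, eps: float = 1e-12):
    """Plain Inception Score from an (N, K) matrix of class probabilities p(y|x).

    Returns (mean, std) of exp(mean_x KL(p(y|x) || p(y))) over `num_splits` chunks.
    """
    scores = []
    for chunk in np.array_split(probs, num_splits):
        p_y = chunk.mean(axis=0, keepdims=True)  # marginal p(y) within this chunk
        kl = (chunk * (np.log(chunk + eps) - np.log(p_y + eps))).sum(axis=1)
        scores.append(np.exp(kl.mean()))
    return float(np.mean(scores)), float(np.std(scores))


rng = np.random.default_rng(0)
logits = rng.standard_normal((1000, 10))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax over classes
mean, std = inception_score_np(probs, num_splits=10)
print(mean, std)
```

The score ranges from 1 (every sample gets the same prediction) to K (confident predictions spread evenly over all K classes).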
Overview
PhotoSynthesis.Metrics helps you concentrate on your experiments without the boilerplate code. The library contains a set of measures and metrics that is constantly being extended. For measures and metrics that can be used as loss functions, corresponding PyTorch modules are implemented.
Installation
```bash
$ pip install photosynthesis-metrics
```
If you want to use the latest features straight from the master branch, clone the repo:
```bash
$ git clone https://github.com/photosynthesis-team/photosynthesis.metrics.git
```
Roadmap
See the open issues for a list of proposed features and known issues.
Community
Contributing
We appreciate all contributions. If you plan to:
- contribute back bug fixes, please do so without any further discussion
- close one of the open issues, please do so if no one has been assigned to it
- contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us
Please see the contribution guide for more information.
Contact
Sergey Kastryulin - @snk4tr - snk4tr@gmail.com
Project Link: https://github.com/photosynthesis-team/photosynthesis.metrics
PhotoSynthesis Team: https://github.com/photosynthesis-team
Other projects by PhotoSynthesis Team:
Acknowledgements
- Pavel Parunin - @PavelParunin - idea proposal and development
- Djamil Zakirov - @zakajd - development
- Denis Prokopenko - @denproc - development
Download files
| Filename | Size | File type | Python version |
|---|---|---|---|
| photosynthesis_metrics-0.4.0-py3-none-any.whl | 58.7 kB | Wheel | py3 |
| photosynthesis_metrics-0.4.0.tar.gz | 45.4 kB | Source | None |
Hashes for photosynthesis_metrics-0.4.0-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 842944095fd83b1a89e7c458f1d25a623a2ff8729c60cbc0ac4b58d6e9ef0ebb |
| MD5 | a5221a638ea85b290f7bdfe975cd53b5 |
| BLAKE2b-256 | abf814023af42786b2c1ac946fe163096637f993b7a94f9ba5277d12e48a6140 |

Hashes for photosynthesis_metrics-0.4.0.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | 7d319ac91f144b12f6f9c9418bb3ac250cef1a7c7796c5f4f68f3c61882d39ee |
| MD5 | 6a6cb9bfb998e4ca4cbbd8b162821ba9 |
| BLAKE2b-256 | 8bcf5bde3844464f02f0825166c9890ebf5282c8373bc356fd7c01ab2b9474e3 |