
Measures and metrics for image2image tasks. PyTorch.


PhotoSynthesis.Metrics

License: MIT

PyTorch library with measures and metrics for various image-to-image tasks such as denoising, super-resolution and image generation. This easy-to-use yet flexible and extensive library is developed with a focus on the reliability and reproducibility of results. Use your favourite measures as losses for training neural networks with ready-to-use PyTorch modules.

Getting started

import torch
from photosynthesis_metrics import ssim

prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
ssim_index = ssim(prediction, target, data_range=1.)

Minimal examples

To compute the SSIM index as a measure, use the lower-case function from the library:

import torch
from photosynthesis_metrics import ssim
from typing import Union, Tuple

prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256) 
ssim_index: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = ssim(prediction, target, data_range=1.)
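For intuition about what the index measures, here is a minimal from-scratch sketch of single-scale SSIM in plain PyTorch. It assumes the common defaults from the original SSIM paper (an 11x11 Gaussian window with sigma 1.5, k1 = 0.01, k2 = 0.03) and is not the library's implementation; details such as padding, downsampling or reduction may differ. The helpers gaussian_window and ssim_sketch are hypothetical names.

```python
import torch
import torch.nn.functional as F


def gaussian_window(size: int = 11, sigma: float = 1.5) -> torch.Tensor:
    """Hypothetical helper: normalized 2-D Gaussian kernel of shape (size, size)."""
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    g = g / g.sum()
    return g[:, None] * g[None, :]  # outer product -> 2-D kernel


def ssim_sketch(x: torch.Tensor, y: torch.Tensor, data_range: float = 1.0,
                k1: float = 0.01, k2: float = 0.03) -> torch.Tensor:
    """Mean SSIM over a (N, C, H, W) batch (illustrative sketch, not the library's code)."""
    channels = x.size(1)
    # One Gaussian kernel per channel, applied depthwise via groups=channels
    kernel = gaussian_window().to(x)[None, None].repeat(channels, 1, 1, 1)
    c1, c2 = (k1 * data_range) ** 2, (k2 * data_range) ** 2

    def filt(t: torch.Tensor) -> torch.Tensor:
        return F.conv2d(t, kernel, groups=channels)  # 'valid' filtering, no padding

    mu_x, mu_y = filt(x), filt(y)
    # Local (co)variances as E[ab] - E[a]E[b] under the Gaussian window
    sigma_xx = filt(x * x) - mu_x ** 2
    sigma_yy = filt(y * y) - mu_y ** 2
    sigma_xy = filt(x * y) - mu_x * mu_y

    ssim_map = ((2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (sigma_xx + sigma_yy + c2))
    return ssim_map.mean()
```

Comparing an image with itself yields 1.0, which is a quick sanity check; data_range enters only through the stabilizing constants c1 and c2.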

To use SSIM as a loss function, use the corresponding PyTorch module:

import torch
from photosynthesis_metrics import SSIMLoss

loss = SSIMLoss()
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target, data_range=1.)
output.backward()

To compute the MS-SSIM index as a measure, use the lower-case function from the library:

import torch
from photosynthesis_metrics import multi_scale_ssim

prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256) 
ms_ssim_index: torch.Tensor = multi_scale_ssim(prediction, target, data_range=1.)
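For reference, MS-SSIM as defined by Wang, Simoncelli and Bovik (2003) evaluates the contrast and structure terms of SSIM at M dyadic scales (downsampling by 2 between scales) and the luminance term only at the coarsest scale M. The exponents are per-scale weights, and L is the data range; whether the library uses exactly the weights from that paper is not stated here.

```latex
\mathrm{MS\text{-}SSIM}(x, y) = \left[l_M(x, y)\right]^{\alpha_M}
  \prod_{j=1}^{M} \left[c_j(x, y)\right]^{\beta_j} \left[s_j(x, y)\right]^{\gamma_j},
\qquad
l(x, y) = \frac{2\mu_x\mu_y + C_1}{\mu_x^2 + \mu_y^2 + C_1},\quad
c(x, y) = \frac{2\sigma_x\sigma_y + C_2}{\sigma_x^2 + \sigma_y^2 + C_2},\quad
s(x, y) = \frac{\sigma_{xy} + C_3}{\sigma_x\sigma_y + C_3},
\qquad
C_1 = (k_1 L)^2,\; C_2 = (k_2 L)^2,\; C_3 = C_2 / 2
```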

To use MS-SSIM as a loss function, use the corresponding PyTorch module:

import torch
from photosynthesis_metrics import MultiScaleSSIMLoss

loss = MultiScaleSSIMLoss()
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target, data_range=1.)
output.backward()

To compute total variation (TV) as a measure, use the lower-case function from the library:

import torch
from photosynthesis_metrics import total_variation

data = torch.rand(3, 3, 256, 256) 
tv: torch.Tensor = total_variation(data)
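As a rough picture of what total variation measures, the sketch below sums absolute differences between vertically and horizontally adjacent pixels (the anisotropic L1 form) for each image in a batch. This form is an assumption for illustration: the library may use a different norm or reduction, and tv_sketch is a hypothetical name.

```python
import torch


def tv_sketch(images: torch.Tensor) -> torch.Tensor:
    """Anisotropic L1 total variation per image for a (N, C, H, W) batch."""
    # Differences between neighbouring rows and between neighbouring columns
    dh = images[:, :, 1:, :] - images[:, :, :-1, :]
    dw = images[:, :, :, 1:] - images[:, :, :, :-1]
    # Sum over channels and spatial dims, keeping one value per image
    return dh.abs().sum(dim=(1, 2, 3)) + dw.abs().sum(dim=(1, 2, 3))
```

A constant image has zero total variation while noise drives the value up, which is why TV is a popular smoothness regularizer.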

To use TV as a loss function, use the corresponding PyTorch module:

import torch
from photosynthesis_metrics import TVLoss

loss = TVLoss()
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
output: torch.Tensor = loss(prediction)
output.backward()
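Because a TV loss is differentiable, it is typically combined with a data-fidelity term. The sketch below illustrates that idea without the library: it denoises an image by gradient descent on MSE plus a weighted total-variation penalty. The inline total_variation_l1 function is a hypothetical stand-in for TVLoss, and tv_weight, the step count and the learning rate are arbitrary illustrative choices.

```python
import torch
import torch.nn.functional as F


def total_variation_l1(images: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for TVLoss: mean anisotropic L1 TV over the batch."""
    dh = (images[:, :, 1:, :] - images[:, :, :-1, :]).abs().sum(dim=(1, 2, 3))
    dw = (images[:, :, :, 1:] - images[:, :, :, :-1]).abs().sum(dim=(1, 2, 3))
    return (dh + dw).mean()


def tv_denoise(noisy: torch.Tensor, tv_weight: float = 0.01,
               steps: int = 100, lr: float = 0.01) -> torch.Tensor:
    """Minimize MSE(x, noisy) + tv_weight * TV(x) with Adam, starting from the noisy image."""
    x = noisy.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([x], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = F.mse_loss(x, noisy) + tv_weight * total_variation_l1(x)
        loss.backward()  # gradients flow through the TV term like any other loss
        optimizer.step()
    return x.detach()
```

Raising tv_weight trades fidelity to the noisy input for smoother output.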

To compute VIF as a measure, use the lower-case function vif_p from the library:

import torch
from photosynthesis_metrics import vif_p

predicted = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
vif: torch.Tensor = vif_p(predicted, target, data_range=1.)

To use VIF as a loss function, use the corresponding PyTorch class:

import torch
from photosynthesis_metrics import VIFLoss

loss = VIFLoss(sigma_n_sq=2.0, data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()

Note that VIFLoss returns 1 - VIF, so minimizing the loss maximizes VIF.

Use the MSID class to compute the MSID score from image features pre-extracted with a feature extractor network:

import torch
from photosynthesis_metrics import MSID

msid_metric = MSID()
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
msid: torch.Tensor = msid_metric(prediction_feats, target_feats)

If image features are not available, extract them with the _compute_feats method of the MSID class. Note that _compute_feats expects a data loader in a predefined format.

import torch
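from torch.utils.data import DataLoader
from photosynthesis_metrics import MSID

# first_dataset and second_dataset are placeholders for your own datasets,
# which must yield batches in the format expected by _compute_feats
first_dl, second_dl = DataLoader(first_dataset), DataLoader(second_dataset)
msid_metric = MSID()
first_feats = msid_metric._compute_feats(first_dl)
second_feats = msid_metric._compute_feats(second_dl)
msid: torch.Tensor = msid_metric(first_feats, second_feats)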
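If the data loader format expected by _compute_feats does not fit your pipeline, one alternative is to extract features yourself and pass the resulting (N, D) tensors to MSID, FID or KID as in the first example of each section. The sketch below assumes a loader that yields (images, labels) tuples and uses a toy untrained network as a placeholder; in practice such features usually come from a pretrained network such as Inception. extract_features and toy_model are hypothetical names.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def extract_features(model: nn.Module, loader: DataLoader) -> torch.Tensor:
    """Run `model` over every batch and stack the outputs into an (N, D) tensor."""
    model.eval()
    feats = []
    for images, _ in loader:  # assumes (images, labels) batches
        feats.append(model(images).flatten(start_dim=1))
    return torch.cat(feats, dim=0)


# Toy placeholder extractor: global average pooling followed by a linear projection
toy_model = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 16))
dataset = TensorDataset(torch.rand(40, 3, 32, 32), torch.zeros(40))
loader = DataLoader(dataset, batch_size=8)
features = extract_features(toy_model, loader)  # shape (40, 16)
```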

Use the FID class to compute the FID score from image features pre-extracted with a feature extractor network:

import torch
from photosynthesis_metrics import FID

fid_metric = FID()
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
fid: torch.Tensor = fid_metric(prediction_feats, target_feats)
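FID compares Gaussian fits of the two feature sets: ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)), where mu and S are the feature means and covariances. The following is a minimal sketch of that formula in plain PyTorch, not the library's code. It computes the trace of the matrix square root through the symmetric product S1^(1/2) S2 S1^(1/2), which has the same eigenvalues as S1 S2 and avoids a general non-symmetric matrix square root. fid_sketch and _sqrtm_psd are hypothetical names.

```python
import torch


def _sqrtm_psd(mat: torch.Tensor) -> torch.Tensor:
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    eigvals, eigvecs = torch.linalg.eigh(mat)
    return eigvecs @ torch.diag(eigvals.clamp(min=0).sqrt()) @ eigvecs.T


def fid_sketch(feats1: torch.Tensor, feats2: torch.Tensor) -> torch.Tensor:
    """Frechet distance between Gaussians fitted to two (N, D) feature sets."""
    feats1, feats2 = feats1.double(), feats2.double()  # float64 for numerical stability
    mu1, mu2 = feats1.mean(dim=0), feats2.mean(dim=0)
    cov1, cov2 = torch.cov(feats1.T), torch.cov(feats2.T)  # torch.cov expects (D, N)
    sqrt_cov1 = _sqrtm_psd(cov1)
    # Tr((S1 S2)^(1/2)) = sum of square roots of the eigenvalues of S1^(1/2) S2 S1^(1/2)
    middle = sqrt_cov1 @ cov2 @ sqrt_cov1
    tr_covmean = torch.linalg.eigvalsh(middle).clamp(min=0).sqrt().sum()
    return ((mu1 - mu2) ** 2).sum() + torch.trace(cov1) + torch.trace(cov2) - 2 * tr_covmean
```

Shifting every feature by a constant c leaves the covariances unchanged, so the distance reduces to D * c^2.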

If image features are not available, extract them with the _compute_feats method of the FID class. Note that _compute_feats expects a data loader in a predefined format.

import torch
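from torch.utils.data import DataLoader
from photosynthesis_metrics import FID

# first_dataset and second_dataset are placeholders for your own datasets,
# which must yield batches in the format expected by _compute_feats
first_dl, second_dl = DataLoader(first_dataset), DataLoader(second_dataset)
fid_metric = FID()
first_feats = fid_metric._compute_feats(first_dl)
second_feats = fid_metric._compute_feats(second_dl)
fid: torch.Tensor = fid_metric(first_feats, second_feats)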

Use the KID class to compute the KID score from image features pre-extracted with a feature extractor network:

import torch
from photosynthesis_metrics import KID

kid_metric = KID()
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
kid: torch.Tensor = kid_metric(prediction_feats, target_feats)
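KID, as proposed by Binkowski et al. (2018), is the squared maximum mean discrepancy (MMD) between the two feature sets under the cubic polynomial kernel k(x, y) = (x.y / d + 1)^3, where d is the feature dimension, estimated without bias. The sketch below computes that estimator once over the full sets; the original formulation averages it over random subsets, and the library's defaults (subset size, number of subsets) are not reproduced here. kid_sketch is a hypothetical name.

```python
import torch


def kid_sketch(feats1: torch.Tensor, feats2: torch.Tensor) -> torch.Tensor:
    """Unbiased MMD^2 with kernel k(x, y) = (x.y / d + 1) ** 3 for (m, d) and (n, d) features."""
    d = feats1.size(1)
    m, n = feats1.size(0), feats2.size(0)
    k_xx = (feats1 @ feats1.T / d + 1) ** 3
    k_yy = (feats2 @ feats2.T / d + 1) ** 3
    k_xy = (feats1 @ feats2.T / d + 1) ** 3
    # Unbiased estimator: drop the diagonal (self-similarity) terms of k_xx and k_yy
    sum_xx = (k_xx.sum() - k_xx.diagonal().sum()) / (m * (m - 1))
    sum_yy = (k_yy.sum() - k_yy.diagonal().sum()) / (n * (n - 1))
    return sum_xx + sum_yy - 2 * k_xy.mean()
```

Because the estimator is unbiased, it can come out slightly negative when both sets are drawn from the same distribution.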

If image features are not available, extract them with the _compute_feats method of the KID class. Note that _compute_feats expects a data loader in a predefined format.

import torch
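from torch.utils.data import DataLoader
from photosynthesis_metrics import KID

# first_dataset and second_dataset are placeholders for your own datasets,
# which must yield batches in the format expected by _compute_feats
first_dl, second_dl = DataLoader(first_dataset), DataLoader(second_dataset)
kid_metric = KID()
first_feats = kid_metric._compute_feats(first_dl)
second_feats = kid_metric._compute_feats(second_dl)
kid: torch.Tensor = kid_metric(first_feats, second_feats)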

Use the inception_score function to compute IS from image features pre-extracted with a feature extractor network. Note that we follow the recommendations of the paper A Note on the Inception Score, which proposed a small modification to the original algorithm:

import torch
from photosynthesis_metrics import inception_score

prediction_feats = torch.rand(10000, 1024)
# Annotated tuple unpacking is a SyntaxError, so unpack without annotations
mean, variance = inception_score(prediction_feats, num_splits=10)
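For reference, the standard Inception Score is exp(E_x KL(p(y|x) || p(y))), computed on classifier class probabilities and usually reported as a mean and spread over several splits. The sketch below implements that standard formulation from logits; it does not reproduce the modification from A Note on the Inception Score that the library follows, and inception_score_sketch is a hypothetical name.

```python
import torch


def inception_score_sketch(logits: torch.Tensor, num_splits: int = 10):
    """Standard IS from (N, num_classes) classifier logits; returns (mean, variance) over splits."""
    scores = []
    for chunk in logits.chunk(num_splits, dim=0):
        log_p_yx = torch.log_softmax(chunk, dim=1)  # log p(y|x) per sample
        p_yx = log_p_yx.exp()
        log_p_y = p_yx.mean(dim=0, keepdim=True).log()  # log of the marginal p(y) within the split
        kl = (p_yx * (log_p_yx - log_p_y)).sum(dim=1)  # KL(p(y|x) || p(y)) per sample
        scores.append(kl.mean().exp())
    scores = torch.stack(scores)
    return scores.mean(), scores.var()
```

Uninformative predictions give a score of 1, while confident predictions spread evenly over C classes give a score of C.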

To compute the difference between the IS values of two sets of image features, use the IS class:

import torch
from photosynthesis_metrics import IS


is_metric = IS(distance='l1') 
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
distance: torch.Tensor = is_metric(prediction_feats, target_feats)

Overview

PhotoSynthesis.Metrics helps you concentrate on your experiments without writing boilerplate code. The library contains a constantly growing set of measures and metrics. For measures and metrics that can be used as loss functions, corresponding PyTorch modules are implemented.

Installation

$ pip install photosynthesis-metrics

If you want to use the latest features straight from the master branch, clone the repository:

$ git clone https://github.com/photosynthesis-team/photosynthesis.metrics.git

Roadmap

See the open issues for a list of proposed features and known issues.

Community

Contributing

We appreciate all contributions. If you plan to:

  • contribute back bug-fixes, please do so without any further discussion
  • close one of the open issues, please do so if no one has been assigned to it
  • contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us

Please see the contribution guide for more information.

Contact

Sergey Kastryulin - @snk4tr - snk4tr@gmail.com

Project Link: https://github.com/photosynthesis-team/photosynthesis.metrics
PhotoSynthesis Team: https://github.com/photosynthesis-team

Other projects by PhotoSynthesis Team:

Acknowledgements

  • Pavel Parunin - @PavelParunin - idea proposal and development
  • Djamil Zakirov - @zakajd - development
  • Denis Prokopenko - @denproc - development

Download files

Files for photosynthesis-metrics, version 0.2.0:

  • photosynthesis_metrics-0.2.0-py3-none-any.whl (42.4 kB): Wheel, Python py3
  • photosynthesis_metrics-0.2.0.tar.gz (32.4 kB): Source
