Package for calculating image generation metrics using PyTorch
PyTorch Implementation of Common Image Generation Metrics
Installation
pip install pytorch-image-generation-metrics
Quick Start
from pytorch_image_generation_metrics import get_inception_score, get_fid
images = ... # [N, 3, H, W] normalized to [0, 1]
IS, IS_std = get_inception_score(images) # Inception Score
FID = get_fid(images, 'path/to/fid_ref.npz') # Frechet Inception Distance
The file path/to/fid_ref.npz is compatible with the official FID implementation.
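To make the compatibility statement concrete, here is a minimal sketch of inspecting such a reference file. It assumes the archive stores the feature mean under the key mu and the feature covariance under the key sigma, as pytorch-fid and the official TTUR code do; the demo file name and the small feature dimension are hypothetical stand-ins (real reference files use 2048-dimensional InceptionV3 features).

```python
import os
import tempfile

import numpy as np

# Assumption: the reference archive holds "mu" (feature mean) and "sigma"
# (feature covariance), following pytorch-fid and the TTUR code.
dim = 8  # demo size only; real reference files use 2048-dim features

rng = np.random.default_rng(0)
features = rng.normal(size=(64, dim))  # stand-in for Inception features

# Write a small synthetic reference file (hypothetical file name).
path = os.path.join(tempfile.mkdtemp(), "fid_ref_demo.npz")
np.savez(path, mu=features.mean(axis=0), sigma=np.cov(features, rowvar=False))

# Read it back and run basic sanity checks on the statistics.
with np.load(path) as ref:
    mu, sigma = ref["mu"], ref["sigma"]

print(mu.shape, sigma.shape)
print("covariance is symmetric:", np.allclose(sigma, sigma.T))
```

The same checks (matching dimensions, symmetric covariance) can be run on a downloaded reference before passing its path to get_fid.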
Notes
The FID implementation is inspired by pytorch-fid.
This repository is developed for personal research. If you find this package useful, please feel free to open issues.
Features
- Currently, this package supports the following metrics:
- Inception Score (IS)
- Fréchet Inception Distance (FID)
- The computation procedures for IS and FID are integrated to avoid multiple forward passes.
- Supports reading images on the fly to prevent out-of-memory issues, especially for large-scale images.
- Supports computation on GPU to speed up some CPU operations, such as np.cov and scipy.linalg.sqrtm.
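As an illustration of what the two metrics compute once the InceptionV3 outputs are available, the following NumPy sketch derives both from the results of a single (simulated) forward pass. It is not this package's code: the function names are hypothetical, and the trace of the matrix square root is obtained from the eigenvalues of the covariance product rather than from scipy.linalg.sqrtm.

```python
import numpy as np


def inception_score(probs, splits=10, eps=1e-12):
    """Inception Score from class probabilities of shape [N, num_classes]."""
    scores = []
    for part in np.array_split(probs, splits):
        p_y = part.mean(axis=0, keepdims=True)  # marginal class distribution
        # KL(p(y|x) || p(y)) for every sample in this split
        kl = (part * (np.log(part + eps) - np.log(p_y + eps))).sum(axis=1)
        scores.append(np.exp(kl.mean()))
    return float(np.mean(scores)), float(np.std(scores))


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Tr(sqrt(sigma1 @ sigma2)) equals the sum of the square roots of the
    # eigenvalues of the product, which are real and non-negative for
    # positive semi-definite inputs (up to numerical noise, hence the clip).
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)


# Stand-ins for the two outputs of one forward pass: class probabilities
# (used by IS) and pooled features (used by FID).
rng = np.random.default_rng(0)
logits = rng.normal(size=(200, 10))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
features = rng.normal(size=(200, 8))

is_mean, is_std = inception_score(probs, splits=4)
mu, sigma = features.mean(axis=0), np.cov(features, rowvar=False)
fid_to_self = frechet_distance(mu, sigma, mu, sigma)
print(is_mean, is_std, fid_to_self)
```

Because both functions consume outputs of the same pass, the expensive network evaluation only has to happen once, which is the point of integrating the two procedures.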
Reproducing Results of Official Implementations on CIFAR-10
Implementation | Train IS | Test IS | Train(50k) vs Test(10k) FID |
---|---|---|---|
Official | 11.24±0.20 | 10.98±0.22 | 3.1508 |
ours | 11.26±0.13 | 10.97±0.19 | 3.1525 |
ours use_torch=True | 11.26±0.15 | 10.97±0.20 | 3.1457 |
The results differ slightly from the official implementations due to the framework differences between PyTorch and TensorFlow.
Documentation
Prepare Statistical Reference for FID
- Download the pre-calculated reference, or
- Calculate the statistical reference for your custom dataset using the command-line tool:
python -m pytorch_image_generation_metrics.fid_ref \
    --path path/to/images \
    --output path/to/fid_ref.npz
See fid_ref.py for details.
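For intuition about what a statistical reference contains, the sketch below accumulates the mean and covariance of feature vectors batch by batch, so the full feature matrix never has to sit in memory. This is an illustration only and not necessarily how fid_ref.py works; the class name RunningStats is hypothetical, and the mu/sigma key names in the final comment are an assumption based on pytorch-fid.

```python
import numpy as np


class RunningStats:
    """Accumulate mean and covariance of feature vectors batch by batch."""

    def __init__(self, dim):
        self.n = 0
        self.total = np.zeros(dim, dtype=np.float64)         # sum of x
        self.outer = np.zeros((dim, dim), dtype=np.float64)  # sum of x x^T

    def update(self, batch):
        batch = np.asarray(batch, dtype=np.float64)
        self.n += batch.shape[0]
        self.total += batch.sum(axis=0)
        self.outer += batch.T @ batch

    def finalize(self):
        mu = self.total / self.n
        # Unbiased covariance, matching np.cov(features, rowvar=False).
        sigma = (self.outer - self.n * np.outer(mu, mu)) / (self.n - 1)
        return mu, sigma


# Stand-in features; in practice these would come from InceptionV3.
rng = np.random.default_rng(0)
features = rng.normal(size=(100, 8))

stats = RunningStats(dim=8)
for start in range(0, len(features), 25):  # process 25 samples at a time
    stats.update(features[start:start + 25])
mu, sigma = stats.finalize()

# Assumed key names, following pytorch-fid:
# np.savez('path/to/fid_ref.npz', mu=mu, sigma=sigma)
```

Accumulating in float64 keeps the subtraction in finalize reasonably accurate; for very large datasets a numerically sturdier update (for example Welford-style) may be preferable.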
Inception Features
- When getting IS or FID, the InceptionV3 model will be loaded into torch.device('cuda:0') by default.
- Change the device argument in the get_* functions to set the torch device.
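On machines without a GPU, the default device will not be usable. A small sketch of choosing a device with a CPU fallback follows; pick_device is a hypothetical helper, not part of this package.

```python
import torch


def pick_device(preferred="cuda:0"):
    """Return the preferred CUDA device if available, otherwise the CPU."""
    if preferred.startswith("cuda") and not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(preferred)


device = pick_device()
print(device)

# The result can then be passed through the device argument, e.g.
# IS, IS_std = get_inception_score(images, device=device)
```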
Using torch.Tensor as Images
- Prepare images as torch.float32 tensors with shape [N, 3, H, W], normalized to [0, 1].

from pytorch_image_generation_metrics import (
    get_inception_score,
    get_fid,
    get_inception_score_and_fid
)

images = ...  # [N, 3, H, W]
assert 0 <= images.min() and images.max() <= 1

# Inception Score
IS, IS_std = get_inception_score(images)

# Frechet Inception Distance
FID = get_fid(images, 'path/to/fid_ref.npz')

# Inception Score & Frechet Inception Distance
(IS, IS_std), FID = get_inception_score_and_fid(images, 'path/to/fid_ref.npz')
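Images often start out in a different layout or value range than the one required here. The sketch below shows two common conversions; both helper names are hypothetical, and the [-1, 1] case assumes a generator whose last layer is a tanh.

```python
import torch


def to_unit_range(images_uint8):
    """Convert uint8 images [N, H, W, 3] in [0, 255] to float32 [N, 3, H, W] in [0, 1]."""
    if images_uint8.dtype != torch.uint8:
        raise TypeError("expected a uint8 tensor")
    return images_uint8.permute(0, 3, 1, 2).to(torch.float32) / 255.0


def from_tanh_range(images):
    """Map generator outputs from [-1, 1] (e.g. after tanh) to [0, 1]."""
    return ((images + 1.0) / 2.0).clamp(0.0, 1.0)


# Stand-in for decoded image files in channels-last layout.
raw = torch.randint(0, 256, (4, 32, 32, 3), dtype=torch.uint8)
images = to_unit_range(raw)

print(images.shape, images.dtype)
assert 0 <= images.min() and images.max() <= 1
```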
Using PyTorch DataLoader to Provide Images
- Use pytorch_image_generation_metrics.ImageDataset to collect images from your storage, or use your custom torch.utils.data.Dataset.

from pytorch_image_generation_metrics import ImageDataset
from torch.utils.data import DataLoader

dataset = ImageDataset(path_to_dir, exts=['png', 'jpg'])
loader = DataLoader(dataset, batch_size=50, num_workers=4)
You can wrap a generative model in a dataset to support generating images on the fly.
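import torch
from torch.utils.data import Dataset, DataLoader

class GeneratorDataset(Dataset):
    def __init__(self, G, noise_dim):
        self.G = G
        self.noise_dim = noise_dim

    def __len__(self):
        return 50000  # number of images to generate

    def __getitem__(self, index):
        # Generate one image and drop the batch dimension so that the
        # DataLoader stacks samples into [N, 3, H, W].
        # The images must be normalized to [0, 1].
        with torch.no_grad():
            return self.G(torch.randn(1, self.noise_dim))[0]

dataset = GeneratorDataset(G, noise_dim=128)
loader = DataLoader(dataset, batch_size=50, num_workers=0)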
- Calculate metrics

from pytorch_image_generation_metrics import (
    get_inception_score,
    get_fid,
    get_inception_score_and_fid
)

# Inception Score
IS, IS_std = get_inception_score(loader)

# Frechet Inception Distance
FID = get_fid(loader, 'path/to/fid_ref.npz')

# Inception Score & Frechet Inception Distance
(IS, IS_std), FID = get_inception_score_and_fid(loader, 'path/to/fid_ref.npz')
Load Images from a Directory
- Calculate metrics for images in a directory and its subfolders.
from pytorch_image_generation_metrics import (
    get_inception_score_from_directory,
    get_fid_from_directory,
    get_inception_score_and_fid_from_directory
)

IS, IS_std = get_inception_score_from_directory('path/to/images')
FID = get_fid_from_directory('path/to/images', 'path/to/fid_ref.npz')
(IS, IS_std), FID = get_inception_score_and_fid_from_directory(
    'path/to/images', 'path/to/fid_ref.npz')
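Before running the directory-based functions, it can be useful to check which files a recursive scan would pick up. The following standard-library sketch lists image files in a directory and its subfolders; list_images is a hypothetical helper written for illustration, and the package's own file discovery may differ (for example in how it treats extension case).

```python
import tempfile
from pathlib import Path


def list_images(root, exts=("png", "jpg")):
    """Recursively collect files under `root` whose extension is in `exts`."""
    wanted = {"." + e.lower().lstrip(".") for e in exts}
    return sorted(
        p for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower() in wanted
    )


# Build a tiny directory tree to demonstrate the scan.
root = Path(tempfile.mkdtemp())
(root / "sub").mkdir()
for name in ("a.png", "sub/b.JPG", "sub/notes.txt"):
    (root / name).touch()

found = list_images(root)
print([p.relative_to(root).as_posix() for p in found])
```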
Accelerating Matrix Computation with PyTorch
- Set use_torch=True when calling functions like get_inception_score, get_fid, etc.
- WARNING: when use_torch=True is used, the FID might be nan due to the unstable implementation of the matrix square root.
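One way to guard against the nan case is to try the accelerated path first and recompute with the default path only when the result is not finite. The helper below is a hypothetical sketch that takes two callables, so it makes no assumption about the package beyond the use_torch argument described above.

```python
import math


def fid_with_fallback(fast_fid, stable_fid):
    """Call fast_fid(); if the result is nan or inf, call stable_fid() instead."""
    value = fast_fid()
    if math.isfinite(value):
        return value
    return stable_fid()


# Demonstration with stand-in callables instead of real metric calls.
result = fid_with_fallback(lambda: float("nan"), lambda: 3.15)
print(result)

# With the package, the callables could look like this:
# fid = fid_with_fallback(
#     lambda: get_fid(images, 'path/to/fid_ref.npz', use_torch=True),
#     lambda: get_fid(images, 'path/to/fid_ref.npz'),
# )
```

Note that the fallback repeats the whole computation, including the InceptionV3 forward passes, so it trades extra time in the rare failing case for a faster common case.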
Tested Versions
python 3.9 + torch 1.13.1 + CUDA 11.7
python 3.9 + torch 2.3.0 + CUDA 12.1
License
This implementation is licensed under the Apache License 2.0.
This implementation is derived from pytorch-fid, licensed under the Apache License 2.0.
FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see https://arxiv.org/abs/1706.08500
The original implementation of FID is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. See https://github.com/bioinf-jku/TTUR.