Package for calculating GAN metrics using PyTorch
PyTorch Implementation of Common GAN Metrics
Install
```shell
pip install pytorch-gan-metrics
```
Requirements:
- torch>=1.8.2
- torchvision>=0.9.2
Quick Start
```python
from pytorch_gan_metrics import get_inception_score, get_fid

images = ...  # [N, 3, H, W] normalized to [0, 1]
IS, IS_std = get_inception_score(images)          # Inception Score
FID = get_fid(images, 'path/to/statistics.npz')   # Frechet Inception Distance
```
`path/to/statistics.npz` is compatible with the official FID implementation.
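`get_inception_score` returns a mean and a standard deviation because the score is computed over several splits of the data; the official implementation uses 10 splits. As a rough illustration of what is being computed, here is a minimal pure-Python sketch of the standard formula IS = exp(E_x[KL(p(y|x) || p(y))]) applied to softmax probabilities such as those produced by InceptionV3. The function name `inception_score` and its `splits` parameter are hypothetical and do not describe this package's internals; the sketch assumes there are at least `splits` rows of probabilities.

```python
import math
from statistics import mean, pstdev


def inception_score(probs, splits=10):
    """Hypothetical sketch: Inception Score from softmax outputs, one row per image.

    For each split, IS = exp(mean over images of KL(p(y|x) || p(y))), where
    p(y) is the marginal class distribution within that split. Returns the
    mean and (population) standard deviation of the per-split scores.
    """
    n = len(probs)
    scores = []
    for k in range(splits):
        part = probs[k * n // splits:(k + 1) * n // splits]
        num_classes = len(part[0])
        # Marginal p(y) estimated over the images in this split
        p_y = [sum(row[c] for row in part) / len(part) for c in range(num_classes)]
        kl = []
        for row in part:
            # Terms with p(y|x) == 0 contribute nothing to the KL divergence
            kl.append(sum(p * math.log(p / q) for p, q in zip(row, p_y) if p > 0))
        scores.append(math.exp(mean(kl)))
    return mean(scores), pstdev(scores)
```

For example, two images that are each confidently classified into a different one of two classes give IS = 2 with `splits=1`, the maximum possible for two classes, while identical predictions for every image give IS = 1.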
Notes
The FID implementation is inspired by pytorch-fid.
This repository is developed for personal research. If you find this package useful for your own work, feel free to open issues.
Features
- Currently, this package supports the following metrics:
  - Inception Score (IS)
  - Fréchet Inception Distance (FID)
- IS and FID are computed within a single procedure to avoid running multiple forward passes.
- Images can be read on the fly to avoid running out of memory, especially for large-scale image sets.
- Computation can run on the GPU to speed up operations that are otherwise performed on the CPU, such as `np.cov` and `scipy.linalg.sqrtm`.
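FID, introduced by Heusel et al., compares two Gaussians fitted to Inception features: FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2)). The package relies on `np.cov` and `scipy.linalg.sqrtm` for this step; the NumPy sketch below swaps in a different technique for the matrix square root term, a symmetric eigendecomposition, using the fact that sigma1 sigma2 has the same eigenvalues as sigma1^(1/2) sigma2 sigma1^(1/2). It is an illustration under those assumptions, not the package's implementation, and `frechet_distance` is a hypothetical name.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Hypothetical sketch: FID between N(mu1, sigma1) and N(mu2, sigma2).

    FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2))
    """
    mu1, mu2 = np.asarray(mu1, dtype=float), np.asarray(mu2, dtype=float)
    sigma1, sigma2 = np.asarray(sigma1, dtype=float), np.asarray(sigma2, dtype=float)

    # Symmetric square root of sigma1 via eigendecomposition (sigma1 is PSD);
    # tiny negative eigenvalues from round-off are clipped to zero.
    w, v = np.linalg.eigh(sigma1)
    sqrt_s1 = (v * np.sqrt(np.clip(w, 0, None))) @ v.T

    # Tr((sigma1 sigma2)^(1/2)) is the sum of the square roots of the
    # eigenvalues of the symmetric PSD matrix sqrt_s1 @ sigma2 @ sqrt_s1.
    m = sqrt_s1 @ sigma2 @ sqrt_s1
    eig = np.linalg.eigvalsh((m + m.T) / 2)
    tr_covmean = np.sqrt(np.clip(eig, 0, None)).sum()

    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)
```

As a sanity check, with equal means, sigma1 = I and sigma2 = 4I in two dimensions, the formula gives 2 + 8 - 2 * 4 = 2.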
Reproducing Results of Official Implementations on CIFAR-10
| | Train IS | Test IS | Train(50k) vs Test(10k) FID |
|---|---|---|---|
| Official | 11.24±0.20 | 10.98±0.22 | 3.1508 |
| pytorch-gan-metrics | 11.26±0.14 | 10.96±0.35 | 3.1518 |
| pytorch-gan-metrics `use_torch=True` | 11.26±0.15 | 10.96±0.19 | 3.1509 |
The results differ slightly from those of the official implementations because of framework differences between PyTorch and TensorFlow.
Documentation
Prepare Statistics (for FID)
- Download precalculated statistics, or
- Calculate statistics for your custom dataset using the command-line tool:
```shell
python -m pytorch_gan_metrics.calc_fid_stats \
    --path path/to/images \
    --stats path/to/statistics.npz
```
See calc_fid_stats.py for details.
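The statistics file stores the mean and covariance of InceptionV3 pool features (2048-dimensional) computed over a reference dataset. Since the file is meant to be compatible with the official FID implementation, the sketch below assumes the `mu`/`sigma` keys used by the official TTUR statistics files. The helper names `save_fid_stats` and `load_fid_stats` are hypothetical, and the real `calc_fid_stats` script may differ in details such as how features are extracted.

```python
import numpy as np


def save_fid_stats(features, path):
    """Hypothetical sketch: save mean/covariance of Inception features to .npz.

    features: array of shape [N, D] (D = 2048 for InceptionV3 pool features).
    Assumes the 'mu'/'sigma' keys of the official FID statistics files.
    """
    features = np.asarray(features, dtype=np.float64)
    mu = features.mean(axis=0)                # shape [D]
    sigma = np.cov(features, rowvar=False)    # shape [D, D], rows are samples
    np.savez(path, mu=mu, sigma=sigma)


def load_fid_stats(path):
    """Hypothetical sketch: read back (mu, sigma) from a statistics file."""
    with np.load(path) as f:
        return f['mu'][:], f['sigma'][:]
```

Loading a downloaded statistics file the same way is a quick check that its keys and shapes match what you expect before running a long evaluation.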
Inception Features
- When computing IS or FID, InceptionV3 is loaded onto `torch.device('cuda:0')` if a GPU is available; otherwise, `torch.device('cpu')` is used.
- Change the `device` argument of the `get_*` functions to set the torch device.
Using torch.Tensor as images
- Prepare images of type `torch.float32` with shape `[N, 3, H, W]`, normalized to `[0, 1]`.
```python
from pytorch_gan_metrics import (
    get_inception_score,
    get_fid,
    get_inception_score_and_fid)

images = ...  # [N, 3, H, W], torch.float32
assert 0 <= images.min() and images.max() <= 1

# Inception Score
IS, IS_std = get_inception_score(images)

# Frechet Inception Distance
FID = get_fid(images, 'path/to/statistics.npz')

# Inception Score & Frechet Inception Distance
(IS, IS_std), FID = get_inception_score_and_fid(images, 'path/to/statistics.npz')
```
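Images often arrive as `uint8` arrays of shape [N, H, W, 3], or from a generator whose last layer is `tanh` (range [-1, 1]), and need converting before they satisfy the `[0, 1]` assertion. A minimal NumPy sketch with hypothetical helper names; wrapping the result with `torch.from_numpy` gives the `torch.float32` tensor expected here.

```python
import numpy as np


def uint8_hwc_to_unit_nchw(images):
    """Hypothetical helper: uint8 [N, H, W, 3] in [0, 255] -> float32 [N, 3, H, W] in [0, 1]."""
    images = np.asarray(images)
    assert images.dtype == np.uint8 and images.ndim == 4 and images.shape[-1] == 3
    # Move channels first and make the copy contiguous before scaling
    chw = np.ascontiguousarray(images.transpose(0, 3, 1, 2))
    return chw.astype(np.float32) / 255.0


def tanh_output_to_unit(images):
    """Hypothetical helper: map outputs in [-1, 1] to [0, 1], clipping any overshoot."""
    return np.clip((np.asarray(images, dtype=np.float32) + 1.0) / 2.0, 0.0, 1.0)
```

Usage, assuming `batch` is your own `uint8` array: `images = torch.from_numpy(uint8_hwc_to_unit_nchw(batch))`.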
Using PyTorch DataLoader to Provide Images
- Use `pytorch_gan_metrics.ImageDataset` to collect images from your storage, or use your own custom `torch.utils.data.Dataset`.
```python
from torch.utils.data import DataLoader

from pytorch_gan_metrics import ImageDataset

path_to_dir = 'path/to/images'  # directory containing your images
dataset = ImageDataset(path_to_dir, exts=['png', 'jpg'])
loader = DataLoader(dataset, batch_size=50, num_workers=4)
```
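- It is possible to wrap a generative model in a dataset so that images are generated on the fly. Remember to set `num_workers=0` to avoid copying the model across worker processes.
```python
import torch
from torch.utils.data import DataLoader, Dataset


class GeneratorDataset(Dataset):
    def __init__(self, G, z_dim):
        self.G = G          # trained generator on CUDA, outputs images in [0, 1]
        self.z_dim = z_dim  # latent dimension

    def __len__(self):
        return 50000

    def __getitem__(self, index):
        # Sample one latent vector and generate a single [3, H, W] image
        with torch.no_grad():
            return self.G(torch.randn(1, self.z_dim).cuda())[0]


dataset = GeneratorDataset(G, z_dim=128)
loader = DataLoader(dataset, batch_size=50, num_workers=0)
```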
- Calculate metrics
```python
from pytorch_gan_metrics import (
    get_inception_score,
    get_fid,
    get_inception_score_and_fid)

# Inception Score
IS, IS_std = get_inception_score(loader)

# Frechet Inception Distance
FID = get_fid(loader, 'path/to/statistics.npz')

# Inception Score & Frechet Inception Distance
(IS, IS_std), FID = get_inception_score_and_fid(loader, 'path/to/statistics.npz')
```
Load Images from a Directory
- Calculate metrics for images in a directory and its subfolders.
```python
from pytorch_gan_metrics import (
    get_inception_score_from_directory,
    get_fid_from_directory,
    get_inception_score_and_fid_from_directory)

IS, IS_std = get_inception_score_from_directory(
    'path/to/images')
FID = get_fid_from_directory(
    'path/to/images', 'path/to/statistics.npz')
(IS, IS_std), FID = get_inception_score_and_fid_from_directory(
    'path/to/images', 'path/to/statistics.npz')
```
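To check which files a recursive search of a directory and its subfolders would pick up, here is a hypothetical stdlib sketch of an extension-filtered walk. It is not the implementation used by `ImageDataset` or the `*_from_directory` functions; the default extensions mirror the `exts=['png', 'jpg']` example, and matching extensions case-insensitively is an assumption of this sketch.

```python
from pathlib import Path


def list_images(root, exts=('png', 'jpg')):
    """Hypothetical helper: recursively collect files under `root` whose
    extension (case-insensitive, without the dot) is in `exts`."""
    wanted = {e.lower().lstrip('.') for e in exts}
    return sorted(
        str(p) for p in Path(root).rglob('*')
        if p.is_file() and p.suffix.lower().lstrip('.') in wanted)
```

Comparing `len(list_images('path/to/images'))` with the number of images you expect is a cheap way to catch a wrong path or extension before computing metrics.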
Accelerating Matrix Computation by PyTorch
- Set `use_torch=True` when calling the `get_*` functions, such as `get_inception_score`, `get_fid`, etc.
- WARNING: when `use_torch=True` is used, the FID might be `nan` due to the numerically unstable implementation of the matrix square root.
- This option is recommended when evaluating generative models on a server equipped with fast GPUs but a low-frequency CPU.
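Because a `nan` FID is a known failure mode of the torch path, one option is to retry with `use_torch=False`. The hypothetical helper below accepts any callable that takes a `use_torch` keyword, so it can wrap `get_fid`; note that the retry calls `get_fid` again, so in the failure case the Inception features are computed a second time.

```python
import math


def fid_with_fallback(compute_fid):
    """Hypothetical helper: try the torch-based FID first, retry without it on nan.

    compute_fid: callable accepting `use_torch` (bool) and returning the FID.

    Example (assuming `images` as in the sections above):
        FID = fid_with_fallback(
            lambda use_torch: get_fid(images, 'path/to/statistics.npz',
                                      use_torch=use_torch))
    """
    fid = compute_fid(use_torch=True)
    if math.isnan(fid):
        # The torch matrix square root can be unstable; use the default path
        fid = compute_fid(use_torch=False)
    return fid
```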
Tested Versions
python 3.9 + torch 1.8.2 + CUDA 10.2
python 3.9 + torch 1.11.0 + CUDA 10.2
python 3.9 + torch 1.12.1 + CUDA 10.2
License
This implementation is licensed under the Apache License 2.0.
This implementation is derived from pytorch-fid, licensed under the Apache License 2.0.
FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see https://arxiv.org/abs/1706.08500
The original implementation of FID is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. See https://github.com/bioinf-jku/TTUR.
Hashes for pytorch_gan_metrics-0.5.3.tar.gz (source distribution)

| Algorithm | Hash digest |
|---|---|
| SHA256 | df5a32f57f5b669dfcc8cc008221ebc9022234f3e3a2dd199f66d71756408102 |
| MD5 | 5183263aec8acd80c00dc7224c4ad21e |
| BLAKE2b-256 | 561f92ab220f701a35d7c22c595ad2f280913c10d6316249eed4d8cd93fa6c46 |

Hashes for pytorch_gan_metrics-0.5.3-py3-none-any.whl (built distribution)

| Algorithm | Hash digest |
|---|---|
| SHA256 | e70443cf6ec8b6f74acb1a9643c406d2134de771c728ab2b08dc7a137e617b92 |
| MD5 | d7d51087239b0fd1c087718ca9e782d0 |
| BLAKE2b-256 | 02832c65b41b01ec26fcd2aea0943f4e2373759314776b9f0a2db9ee59e459cc |