Package for calculating GAN metrics using PyTorch
PyTorch implementation of common GAN metrics
Notes
The FID implementation is inspired by pytorch-fid.
Features
- Currently, this package supports the following metrics:
  - Inception Score (IS)
  - Fréchet Inception Distance (FID)
- The IS and FID computations are integrated, so both metrics share a single forward pass through the Inception network instead of requiring separate passes.
- For both metrics, images are read on the fly.
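The Inception Score is commonly defined as IS = exp(E_x[KL(p(y|x) ‖ p(y))]). Here p(y|x) is the Inception network's class distribution for image x, and p(y) is the average of those distributions over the images. The NumPy sketch below computes that quantity from precomputed softmax outputs. It returns a mean and a standard deviation over splits, which is the form in which IS is usually reported. The function name `inception_score` and the default of 10 splits are illustrative assumptions, not this package's API. The sketch also does not run the Inception network itself.

```python
import numpy as np
from typing import Tuple


def inception_score(probs: np.ndarray, splits: int = 10) -> Tuple[float, float]:
    """Compute IS = exp(E_x[KL(p(y|x) || p(y))]) from softmax outputs.

    probs: array of shape [N, num_classes]; row i is p(y|x_i).
    Returns the mean and standard deviation of the score over `splits` chunks.
    """
    eps = 1e-12  # avoids log(0) for zero probabilities
    scores = []
    for chunk in np.array_split(probs, splits):
        # Marginal class distribution p(y) estimated within this chunk.
        p_y = chunk.mean(axis=0, keepdims=True)
        # Per-image KL(p(y|x) || p(y)), summed over classes.
        kl = chunk * (np.log(chunk + eps) - np.log(p_y + eps))
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))
```

Confident predictions spread evenly across k classes push the score toward k. Identical predictions for every image give a score of 1.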
Reproducing Results of Official Implementations
CIFAR-10

| | Train IS | Test IS | FID (Train 50k vs. Test 10k) |
|---|---|---|---|
| Official | 11.24±0.20 | 10.98±0.22 | 3.1508 |
| pytorch-gan-metrics | 11.26±0.27 | 10.97±0.33 | 3.1517 |

Because of framework differences between PyTorch and TensorFlow, the results differ slightly from those of the official implementations.
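FID fits a Gaussian, with mean μ and covariance Σ, to the Inception features of each image set. It then measures FID = ‖μ₁ − μ₂‖² + Tr(Σ₁ + Σ₂ − 2(Σ₁Σ₂)^{1/2}). Below is a minimal NumPy sketch of that distance, assuming both sets of statistics are already computed. pytorch-fid takes the matrix square root with `scipy.linalg.sqrtm`. This sketch swaps that in for a different technique: it takes the trace of the square root as the sum of the square roots of the eigenvalues of Σ₁Σ₂, so round-off can differ slightly. The name `frechet_distance` is hypothetical.

```python
import numpy as np


def frechet_distance(mu1: np.ndarray, sigma1: np.ndarray,
                     mu2: np.ndarray, sigma2: np.ndarray) -> float:
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # sigma1 @ sigma2 has the same eigenvalues as the PSD matrix
    # sigma1^{1/2} sigma2 sigma1^{1/2}, so they are real and non-negative
    # up to round-off; clip tiny negative/imaginary noise before sqrt.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * tr_covmean)
```

Identical statistics give a distance of 0. For isotropic covariances, the covariance term reduces to the sum of squared differences of the standard deviations.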
Install
python setup.py install
Prepare Statistics for FID
- Download precalculated statistics for a dataset, or
- Calculate statistics for your custom dataset using the command-line tool:
python -m pytorch_gan_metrics.calc_fid_stats --path path/to/images --output name.npz
See calc_fid_stats.py for implementation details.
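A statistics file holds the mean and covariance of Inception features over a dataset. pytorch-fid, for example, uses 2048-dimensional pool3 features by default. The sketch below shows how such a file could be produced from an already-extracted feature matrix. The helper names are hypothetical. The `mu`/`sigma` key names follow pytorch-fid's convention, and it is an assumption that this package reads the same keys, so check calc_fid_stats.py before relying on them.

```python
import numpy as np
from typing import Tuple


def compute_statistics(features: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the mean [D] and covariance [D, D] of features with shape [N, D]."""
    mu = features.mean(axis=0)
    sigma = np.cov(features, rowvar=False)  # unbiased estimate, divides by N - 1
    return mu, sigma


def save_statistics(path: str, mu: np.ndarray, sigma: np.ndarray) -> None:
    # Key names assumed to match pytorch-fid ("mu", "sigma").
    np.savez(path, mu=mu, sigma=sigma)


def load_statistics(path: str) -> Tuple[np.ndarray, np.ndarray]:
    with np.load(path) as f:
        return f["mu"], f["sigma"]
```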
Documentation
Using torch.Tensor as images
- Prepare images of type torch.float32 with shape [N, 3, H, W] and values normalized to [0, 1].

    from pytorch_gan_metrics import (
        get_inception_score, get_fid, get_inception_score_and_fid)

    images = ...  # [N, 3, H, W]
    assert 0 <= images.min() and images.max() <= 1

    # Inception Score
    IS, IS_std = get_inception_score(images)

    # Frechet Inception Distance
    FID = get_fid(images, 'path/to/statistics.npz')

    # Inception Score + Frechet Inception Distance
    (IS, IS_std), FID = get_inception_score_and_fid(
        images, 'path/to/statistics.npz')
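Images loaded with common image libraries often arrive as uint8 arrays of shape [N, H, W, 3] with values in [0, 255]. One way to convert them to the layout and range required above is sketched here in NumPy. The helper name `to_unit_nchw` is hypothetical.

```python
import numpy as np


def to_unit_nchw(images_uint8: np.ndarray) -> np.ndarray:
    """Convert uint8 images [N, H, W, 3] in [0, 255] to float32 [N, 3, H, W] in [0, 1]."""
    if (images_uint8.dtype != np.uint8 or images_uint8.ndim != 4
            or images_uint8.shape[-1] != 3):
        raise ValueError("expected a uint8 array of shape [N, H, W, 3]")
    x = images_uint8.astype(np.float32) / 255.0  # scale to [0, 1]
    # Move channels first and make the memory layout contiguous.
    return np.ascontiguousarray(x.transpose(0, 3, 1, 2))
```

The result can then be wrapped with `torch.from_numpy(...)` to obtain a torch.float32 tensor.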
Using PyTorch DataLoader
- Use pytorch_gan_metrics.ImageDataset to collect images on disk. Alternatively, use a custom dataset whose __getitem__ returns only an image ([3, H, W], values in [0, 1]).

    from torch.utils.data import DataLoader
    from pytorch_gan_metrics import ImageDataset

    dataset = ImageDataset(path_to_dir, exts=['png', 'jpg'])
    loader = DataLoader(dataset, batch_size=50, num_workers=4)
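- A generative model can also be wrapped in a dataset so that images are generated on the fly. Set num_workers=0 so that the model is not copied into each worker process.

    import torch
    from torch.utils.data import DataLoader, Dataset


    class GeneratorDataset(Dataset):
        def __init__(self, G, z_dim):
            self.G = G
            self.z_dim = z_dim

        def __len__(self):
            return 50000  # number of images to generate

        def __getitem__(self, index):
            # Sample one latent vector and return a single image [3, H, W].
            # Assumes G is on the GPU and outputs images in [0, 1].
            with torch.no_grad():
                return self.G(torch.randn(1, self.z_dim).cuda())[0]


    dataset = GeneratorDataset(G, z_dim=128)
    loader = DataLoader(dataset, batch_size=50, num_workers=0)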
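The metrics expect inputs in [0, 1], so a generator wrapped this way must output images in that range. Many GAN generators instead end in a tanh layer, with outputs in [-1, 1]; check your model. Such outputs need rescaling before evaluation. Here is a NumPy sketch of the mapping, using a hypothetical helper name. On a tensor inside `__getitem__`, the equivalent is `((x + 1) / 2).clamp(0, 1)`.

```python
import numpy as np


def tanh_to_unit(x: np.ndarray) -> np.ndarray:
    """Map generator outputs from [-1, 1] (e.g. a tanh output layer) to [0, 1]."""
    # Clip guards against small numerical overshoot outside [-1, 1].
    return np.clip((x + 1.0) / 2.0, 0.0, 1.0)
```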
- Calculate metrics:

    from pytorch_gan_metrics import (
        get_inception_score, get_fid, get_inception_score_and_fid)

    # Inception Score
    IS, IS_std = get_inception_score(loader)

    # Frechet Inception Distance
    FID = get_fid(loader, 'path/to/statistics.npz')

    # Inception Score + Frechet Inception Distance
    (IS, IS_std), FID = get_inception_score_and_fid(
        loader, 'path/to/statistics.npz')
From directory
- Calculate metrics for images in a directory:

    from pytorch_gan_metrics import (
        get_inception_score_from_directory,
        get_fid_from_directory,
        get_inception_score_and_fid_from_directory)

    fid_stats_path = 'path/to/statistics.npz'

    IS, IS_std = get_inception_score_from_directory('path/to/images')
    FID = get_fid_from_directory('path/to/images', fid_stats_path)
    (IS, IS_std), FID = get_inception_score_and_fid_from_directory(
        'path/to/images', fid_stats_path)
License
This implementation is licensed under the Apache License 2.0.
This implementation is derived from pytorch-fid, licensed under the Apache License 2.0.
FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see https://arxiv.org/abs/1706.08500
The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. See https://github.com/bioinf-jku/TTUR.