Package for calculating GAN metrics using PyTorch
PyTorch Implementation of Common GAN Metrics
Notes
The FID implementation is inspired by pytorch-fid.
Features
- Currently, this package supports the following metrics:
  - Inception Score (IS)
  - Fréchet Inception Distance (FID)
- The computations of IS and FID are integrated, so computing both metrics requires only one forward pass over the images.
- Images are read on the fly for both metrics.
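The Inception Score reported as mean ± std is commonly computed as exp(E_x[KL(p(y|x) ‖ p(y))]) on each of several splits of the predicted class probabilities (10 splits by convention), and the per-split scores are then averaged. The NumPy sketch below shows that calculation from a precomputed [N, num_classes] softmax array. The function name, the `eps` guard and the default of 10 splits are assumptions for illustration; this is not pytorch-gan-metrics' own code.

```python
import numpy as np


def inception_score(probs, splits=10, eps=1e-12):
    """Sketch of IS from softmax outputs p(y|x) with shape [N, num_classes].

    IS = exp(E_x[KL(p(y|x) || p(y))]) is computed on each split;
    returns the mean and standard deviation over splits.
    """
    probs = np.asarray(probs, dtype=np.float64)
    scores = []
    for part in np.array_split(probs, splits):
        # Marginal class distribution p(y) estimated within this split
        p_y = part.mean(axis=0, keepdims=True)
        # Per-sample KL(p(y|x) || p(y)); eps avoids log(0) for zero probabilities
        kl = part * (np.log(part + eps) - np.log(p_y + eps))
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))
```

Confident one-hot predictions spread evenly over K classes give a score close to K, while identical predictions for every image give a score of 1.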
Reproducing Results of Official Implementations
CIFAR-10

Implementation | Train IS | Test IS | FID (Train 50k vs. Test 10k)
---|---|---|---
Official | 11.24±0.20 | 10.98±0.22 | 3.1508
pytorch-gan-metrics | 11.26±0.27 | 10.97±0.33 | 3.1517

Because of framework differences between PyTorch and TensorFlow, the results differ slightly from those of the official implementations.
Install
```bash
python setup.py install
```
Prepare Statistics for FID
- Download precalculated statistics for a dataset, or
- Calculate statistics for your custom dataset using the command-line tool:
```bash
python -m pytorch_gan_metrics.calc_fid_stats --path path/to/images --output name.npz
```
See calc_fid_stats.py for implementation details.
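Given two sets of statistics, FID is the Fréchet distance between Gaussians fitted to Inception features: d² = ‖μ₁ − μ₂‖² + Tr(Σ₁ + Σ₂ − 2(Σ₁Σ₂)^{1/2}). The NumPy sketch below illustrates that formula; it is not the package's implementation. Two assumptions: the .npz files store the arrays under the keys `mu` and `sigma` (the pytorch-fid convention, not verified here for this package), and the trace of the matrix square root is computed through a symmetric eigendecomposition instead of `scipy.linalg.sqrtm`, which pytorch-fid uses. For covariance matrices the two routes are mathematically equivalent.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Sketch of d^2 = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)).

    Tr((S1 S2)^(1/2)) equals the sum of square roots of the eigenvalues of
    S1^(1/2) S2 S1^(1/2), which is symmetric PSD, so only NumPy is needed.
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    # Symmetric square root of sigma1: V diag(sqrt(w)) V^T
    w, v = np.linalg.eigh(sigma1)
    sqrt_s1 = (v * np.sqrt(np.clip(w, 0, None))) @ v.T
    inner = sqrt_s1 @ sigma2 @ sqrt_s1
    # Symmetrize against round-off before taking eigenvalues
    eig = np.linalg.eigvalsh((inner + inner.T) / 2)
    tr_covmean = np.sqrt(np.clip(eig, 0, None)).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)


def fid_from_npz(path_a, path_b):
    # Assumption: each .npz stores arrays under the keys 'mu' and 'sigma'
    with np.load(path_a) as a, np.load(path_b) as b:
        return frechet_distance(a['mu'], a['sigma'], b['mu'], b['sigma'])
```

For example, `fid_from_npz('stats_a.npz', 'stats_b.npz')` compares two statistics files; identical statistics give a distance of 0.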
Documentation
Using torch.Tensor as images
- Prepare images as a torch.float32 tensor with shape [N, 3, H, W] and values normalized to [0, 1].
```python
from pytorch_gan_metrics import (
    get_inception_score, get_fid, get_inception_score_and_fid)

images = ...  # [N, 3, H, W]
assert 0 <= images.min() and images.max() <= 1

# Inception Score
IS, IS_std = get_inception_score(images)

# Frechet Inception Distance
FID = get_fid(images, 'path/to/statistics.npz')

# Inception Score + Frechet Inception Distance
(IS, IS_std), FID = get_inception_score_and_fid(
    images, 'path/to/statistics.npz')
```
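Images decoded into NumPy arrays usually come as uint8 values in 0..255 with layout [N, H, W, 3], so they must be transposed and rescaled before use. The helper below, `to_unit_nchw`, is hypothetical and not part of pytorch-gan-metrics; it is a minimal sketch that assumes exactly that input layout.

```python
import numpy as np


def to_unit_nchw(batch_uint8):
    """Hypothetical helper: uint8 [N, H, W, 3] in 0..255 -> float32 [N, 3, H, W] in [0, 1]."""
    batch = np.asarray(batch_uint8)
    if batch.dtype != np.uint8 or batch.ndim != 4 or batch.shape[-1] != 3:
        raise ValueError("expected a uint8 array of shape [N, H, W, 3]")
    # Move the channel axis to position 1, convert to float32, then rescale
    nchw = np.ascontiguousarray(batch.transpose(0, 3, 1, 2), dtype=np.float32)
    return nchw / np.float32(255.0)
```

The result can be wrapped with `torch.from_numpy(...)` and passed to `get_inception_score`. For generator outputs in [-1, 1] (for example from a tanh layer), `(x + 1) / 2` maps them to the required range instead.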
Using PyTorch DataLoader
- Use pytorch_gan_metrics.ImageDataset to collect images on disk, or use a custom dataset whose __getitem__ returns only an image.
```python
from torch.utils.data import DataLoader
from pytorch_gan_metrics import ImageDataset

dataset = ImageDataset(path_to_dir, exts=['png', 'jpg'])
loader = DataLoader(dataset, batch_size=50, num_workers=4)
```
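- It is possible to wrap a generative model in a dataset so that images are generated on the fly. Remember to set num_workers=0 to avoid copying the model into multiple worker processes.
```python
import torch
from torch.utils.data import Dataset, DataLoader


class GeneratorDataset(Dataset):
    def __init__(self, G, z_dim):
        self.G = G          # generator: [B, z_dim] latents -> [B, 3, H, W] images in [0, 1]
        self.z_dim = z_dim

    def __len__(self):
        return 50000

    def __getitem__(self, index):
        # Sample one latent vector and return the single generated image;
        # no_grad avoids building an autograd graph during evaluation
        with torch.no_grad():
            return self.G(torch.randn(1, self.z_dim).cuda())[0]


dataset = GeneratorDataset(G, z_dim=128)
loader = DataLoader(dataset, batch_size=50, num_workers=0)
```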
- Calculate metrics
```python
from pytorch_gan_metrics import (
    get_inception_score, get_fid, get_inception_score_and_fid)

# Inception Score
IS, IS_std = get_inception_score(loader)

# Frechet Inception Distance
FID = get_fid(loader, 'path/to/statistics.npz')

# Inception Score + Frechet Inception Distance
(IS, IS_std), FID = get_inception_score_and_fid(
    loader, 'path/to/statistics.npz')
```
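A custom dataset only has to return the image itself from __getitem__. A common case is a labelled dataset, such as torchvision's CIFAR10, that returns (image, label) pairs. `ImagesOnly` below is a hypothetical wrapper, not part of pytorch-gan-metrics, that drops the label. As a sketch it relies on DataLoader accepting any map-style object with __getitem__ and __len__; subclassing torch.utils.data.Dataset is the more idiomatic choice and is omitted here only so the example runs without PyTorch.

```python
class ImagesOnly:
    """Hypothetical wrapper: __getitem__ returns only the image, dropping the label."""

    def __init__(self, base):
        self.base = base  # any indexable dataset yielding (image, label) pairs

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        image, _label = self.base[index]
        return image
```

For example, `ImagesOnly(torchvision.datasets.CIFAR10(root, train=True, download=True, transform=torchvision.transforms.ToTensor()))` yields float tensors of shape [3, 32, 32] in [0, 1] and can be passed to `DataLoader` as above (`root` is a placeholder path).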
From directory
- Calculate metrics for images in the directory.
```python
from pytorch_gan_metrics import (
    get_inception_score_from_directory,
    get_fid_from_directory,
    get_inception_score_and_fid_from_directory)

fid_stats_path = 'path/to/statistics.npz'

IS, IS_std = get_inception_score_from_directory('path/to/images')
FID = get_fid_from_directory('path/to/images', fid_stats_path)
(IS, IS_std), FID = get_inception_score_and_fid_from_directory(
    'path/to/images', fid_stats_path)
```
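Before computing metrics from a directory, it can help to check which files will be picked up. The `list_images` helper below is hypothetical and is not the package's code: it collects files whose extension matches, similar in spirit to `ImageDataset(path_to_dir, exts=['png', 'jpg'])`. Searching subdirectories and matching extensions case-insensitively are assumptions made here; the package may behave differently.

```python
from pathlib import Path


def list_images(root, exts=('png', 'jpg')):
    """Hypothetical helper: sorted image paths under `root` whose suffix is in `exts`.

    Assumes a recursive search and case-insensitive extension matching.
    """
    wanted = {'.' + e.lower().lstrip('.') for e in exts}
    return sorted(
        p for p in Path(root).rglob('*')
        if p.is_file() and p.suffix.lower() in wanted
    )
```

For example, `len(list_images('path/to/images'))` gives the number of images that would be evaluated under these assumptions.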
License
This implementation is licensed under the Apache License 2.0.
This implementation is derived from pytorch-fid, licensed under the Apache License 2.0.
FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see https://arxiv.org/abs/1706.08500
The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. See https://github.com/bioinf-jku/TTUR.
File details
Details for the file pytorch_gan_metrics-0.1.0.tar.gz.
File metadata
- Download URL: pytorch_gan_metrics-0.1.0.tar.gz
- Upload date:
- Size: 17.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.8.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | 4c809ae2047bb8becbf3ae2b32dda26d17990cd1f74765063438b0897a003b3a
MD5 | 855a89d615cfc09cb6f2bb0aa6013e54
BLAKE2b-256 | 84288e4627ae01da8acd8d70720a1b60a573f092c15f6af1de76c117ccd08d46
File details
Details for the file pytorch_gan_metrics-0.1.0-py3.6.egg.
File metadata
- Download URL: pytorch_gan_metrics-0.1.0-py3.6.egg
- Upload date:
- Size: 25.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.8.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | d838bad835d3317ad4a8f6eaa69557b6fed35802256302454ce58a54e83ca98c
MD5 | 83d8db0caefdfb9ee3588b5c288445e5
BLAKE2b-256 | 0127befdac4d709263f510b4b88097e7e5efca7a7888430396fbc75376cf867f
File details
Details for the file pytorch_gan_metrics-0.1.0-py3-none-any.whl.
File metadata
- Download URL: pytorch_gan_metrics-0.1.0-py3-none-any.whl
- Upload date:
- Size: 16.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.8.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | e66c90b91654e91484873cddfc2fe5412086810032b04955ae371b4263584303
MD5 | 77314dde65a995a285537d685bfc8798
BLAKE2b-256 | d154c417c61ef3bb6bcc6587100429ad60011fa6510e0e3a8c3b1a59ba125944