
Package for calculating GAN metrics using PyTorch

Project description

PyTorch Implementation of Common GAN Metrics


Notes

The FID implementation is inspired by pytorch-fid.

This repository was developed for personal research. If you find this package useful, feel free to open issues.

Install

pip install pytorch-gan-metrics

Features

  • Currently, this package supports the following metrics:
      • Inception Score (IS)
      • Fréchet Inception Distance (FID)
  • The computation of IS and FID is integrated, so the images pass through the Inception network only once when both metrics are requested.
  • Images can be read on the fly to avoid running out of memory, especially for large-scale image sets.
  • Computation can run on the GPU to speed up operations that are otherwise done on the CPU, such as np.cov and scipy.linalg.sqrtm.
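To show where np.cov and scipy.linalg.sqrtm enter the picture, here is a minimal NumPy/SciPy sketch of the FID formula from Heusel et al.: each image set is summarized by the mean and covariance of its Inception features, and FID is the Fréchet distance between the two resulting Gaussians. The function names frechet_distance and stats_from_features are hypothetical, not part of this package's API; the small-ridge fallback for near-singular products follows the approach used in pytorch-fid.

```python
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2

    # Matrix square root of the covariance product (the sqrtm step).
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        # Nearly singular product: add a small ridge to both covariances.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
    # Discard tiny imaginary parts caused by numerical error.
    covmean = np.real(covmean)

    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * np.trace(covmean))


def stats_from_features(features):
    """Mean and covariance of an [N, D] feature matrix (the np.cov step)."""
    return features.mean(axis=0), np.cov(features, rowvar=False)
```

Identical statistics give a distance of (numerically) zero, and in one dimension the formula reduces to (mu1 - mu2)^2 + (s1 - s2)^2 for standard deviations s1, s2.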

Reproducing Results of Official Implementations on CIFAR-10

                                      Train IS      Test IS       Train(50k) vs Test(10k) FID
Official                              11.24±0.20    10.98±0.22    3.1508
pytorch-gan-metrics                   11.26±0.27    10.97±0.33    3.1517
pytorch-gan-metrics (use_torch=True)  11.26±0.21    10.97±0.34    3.1377

The results differ slightly from the official implementations because of framework differences between PyTorch and TensorFlow.
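For reference, the ± values for IS are conventionally the standard deviation of the score over disjoint splits of the samples (10 splits in the official implementation). The sketch below assumes that convention and takes the softmax outputs of the Inception classifier as an [N, C] NumPy array; inception_score is a hypothetical helper, not this package's get_inception_score.

```python
import numpy as np


def inception_score(probs, splits=10, eps=1e-12):
    """Inception Score from an [N, C] array of softmax class probabilities.

    IS = exp(E_x[KL(p(y|x) || p(y))]), computed separately on `splits`
    disjoint chunks; returns the mean and standard deviation over chunks.
    """
    scores = []
    for part in np.array_split(probs, splits):
        # Marginal class distribution p(y) within this chunk.
        p_y = part.mean(axis=0, keepdims=True)
        # Per-sample KL(p(y|x) || p(y)), summed over classes.
        kl = part * (np.log(part + eps) - np.log(p_y + eps))
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))
```

With uniform predictions the score is 1, its minimum; with confident predictions spread evenly over C classes it approaches C.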

Prepare Statistics for FID

  • Download precalculated statistics, or
  • Calculate statistics for your custom dataset using the command-line tool:
    python -m pytorch_gan_metrics.calc_fid_stats --path path/to/images --output name.npz
    
    See calc_fid_stats.py for implementation details.
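Conceptually, a statistics file only needs the mean vector and covariance matrix of the reference dataset's Inception features, so the reference images do not have to be processed again for every evaluation. The sketch below assumes the pytorch-fid convention of storing them under the keys mu and sigma; check calc_fid_stats.py for the exact keys this package writes. The helpers are hypothetical, and random features stand in for real Inception features.

```python
import os
import tempfile

import numpy as np


def save_fid_stats(features, path):
    """Summarize an [N, D] feature matrix as mean/covariance and save as .npz."""
    mu = features.mean(axis=0)
    sigma = np.cov(features, rowvar=False)
    np.savez(path, mu=mu, sigma=sigma)


def load_fid_stats(path):
    """Load the (mu, sigma) pair written by save_fid_stats."""
    with np.load(path) as f:
        return f["mu"], f["sigma"]


# Example: random features stand in for real Inception features.
feats = np.random.default_rng(0).normal(size=(500, 8))
path = os.path.join(tempfile.mkdtemp(), "stats.npz")
save_fid_stats(feats, path)
mu, sigma = load_fid_stats(path)
```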

Documentation

How to use GPU?

By default, pytorch_gan_metrics uses torch.device('cuda:0') if a GPU is available; otherwise, it computes Inception features on the CPU.

Using torch.Tensor as images

  • Prepare images as a torch.float32 tensor of shape [N, 3, H, W] with values normalized to [0, 1].
    from pytorch_gan_metrics import (get_inception_score,
                                     get_fid,
                                     get_inception_score_and_fid)
    images = ... # [N, 3, H, W]
    assert 0 <= images.min() and images.max() <= 1
    # Inception Score
    IS, IS_std = get_inception_score(images)
    # Frechet Inception Distance
    FID = get_fid(images, 'path/to/statistics.npz')
    # Inception Score + Frechet Inception Distance
    (IS, IS_std), FID = get_inception_score_and_fid(
        images, 'path/to/statistics.npz')
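Images loaded from disk or produced by other tools are often uint8 arrays in [N, H, W, 3] layout. A minimal NumPy sketch of converting them to the layout and value range required above follows; to_nchw_unit_range is a hypothetical helper, and a generator with a tanh output in [-1, 1] would instead need (x + 1) / 2 rather than division by 255.

```python
import numpy as np


def to_nchw_unit_range(images_uint8):
    """Convert [N, H, W, 3] uint8 images to a float32 [N, 3, H, W] array in [0, 1]."""
    images = np.asarray(images_uint8)
    assert images.dtype == np.uint8 and images.ndim == 4 and images.shape[-1] == 3
    images = images.astype(np.float32) / 255.0              # scale to [0, 1]
    return np.ascontiguousarray(images.transpose(0, 3, 1, 2))  # NHWC -> NCHW


batch = np.random.default_rng(0).integers(0, 256, size=(4, 32, 32, 3)).astype(np.uint8)
images = to_nchw_unit_range(batch)
# torch.from_numpy(images) then yields a float32 tensor of shape [N, 3, H, W].
```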
    

Using PyTorch DataLoader to Provide Images

  • Use pytorch_gan_metrics.ImageDataset to collect images from disk, or use a custom torch.utils.data.Dataset whose __getitem__ returns only an image.
    from torch.utils.data import DataLoader
    from pytorch_gan_metrics import ImageDataset
    
    dataset = ImageDataset(path_to_dir, exts=['png', 'jpg'])
    loader = DataLoader(dataset, batch_size=50, num_workers=4)
    
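  • A generative model can also be wrapped in a dataset to generate images on the fly. Set num_workers=0 so that the model is not copied into DataLoader worker processes.
    import torch
    from torch.utils.data import DataLoader, Dataset

    class GeneratorDataset(Dataset):
        def __init__(self, G, z_dim):
            self.G = G          # generator, already on the GPU
            self.z_dim = z_dim  # latent dimension

        def __len__(self):
            # Number of images to generate
            return 50000

        def __getitem__(self, index):
            # Sample one latent vector and return a single [3, H, W] image;
            # G's output should already be scaled to [0, 1]
            with torch.no_grad():
                return self.G(torch.randn(1, self.z_dim).cuda())[0]

    dataset = GeneratorDataset(G, z_dim=128)
    loader = DataLoader(dataset, batch_size=50, num_workers=0)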
    
  • Calculate metrics
    from pytorch_gan_metrics import (get_inception_score,
                                     get_fid,
                                     get_inception_score_and_fid)
    # Inception Score
    IS, IS_std = get_inception_score(loader)
    # Frechet Inception Distance
    FID = get_fid(loader, 'path/to/statistics.npz')
    # Inception Score + Frechet Inception Distance
    (IS, IS_std), FID = get_inception_score_and_fid(
        loader, 'path/to/statistics.npz')
    

Specify Images by a Directory Path

  • Calculate metrics for images in the directory.
    from pytorch_gan_metrics import (
        get_inception_score_from_directory,
        get_fid_from_directory,
        get_inception_score_and_fid_from_directory)
    
    fid_stats_path = 'path/to/statistics.npz'
    IS, IS_std = get_inception_score_from_directory('path/to/images')
    FID = get_fid_from_directory('path/to/images', fid_stats_path)
    (IS, IS_std), FID = get_inception_score_and_fid_from_directory(
        'path/to/images', fid_stats_path)
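Before images in a directory can be read lazily, they have to be discovered. The sketch below is a hypothetical version of that discovery step (recursive search with a case-insensitive extension match); it is not this package's ImageDataset, whose search rules (for example, whether subdirectories are included) may differ.

```python
import tempfile
from pathlib import Path


def list_images(root, exts=("png", "jpg")):
    """Return sorted paths of files under `root` whose extension is in `exts`."""
    wanted = {e.lower().lstrip(".") for e in exts}
    return sorted(
        p for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower().lstrip(".") in wanted
    )


# Example: build a tiny directory tree and list its images.
demo_root = Path(tempfile.mkdtemp())
for name in ["a.png", "b.JPG", "sub/c.jpg", "notes.txt"]:
    path = demo_root / name
    path.parent.mkdir(parents=True, exist_ok=True)
    path.touch()
found = list_images(demo_root)
```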
    

Set PyTorch as backend

  • Set use_torch=True when calling any get_* function, such as get_inception_score or get_fid.
  • WARNING: when use_torch=True, the FID may be NaN because the PyTorch implementation of the matrix square root is numerically unstable.
  • This option is recommended when evaluating generative models on a server with fast GPUs but a low CPU clock speed.
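One common GPU-friendly way to compute a matrix square root is the coupled Newton-Schulz iteration, which uses only matrix products. The NumPy sketch below illustrates it under the assumption that the input is symmetric positive definite; it is not confirmed to be the routine this package uses when use_torch=True. It also hints at why such iterations can be fragile: the usual convergence guarantee requires the scaled matrix to be close enough to the identity, which can fail for near-singular covariances or for the non-symmetric product of two covariances that appears in FID, and a diverging iteration can overflow to inf or NaN.

```python
import numpy as np


def sqrtm_newton_schulz(A, num_iters=50):
    """Approximate square root of a symmetric positive definite matrix A.

    Coupled Newton-Schulz iteration: Y converges to sqrt(A / norm) and
    Z to its inverse. Only matrix products are needed, which suit GPUs.
    """
    A = np.asarray(A, dtype=np.float64)
    identity = np.eye(A.shape[0])
    norm = np.linalg.norm(A)   # Frobenius norm; scales eigenvalues into (0, 1]
    Y = A / norm
    Z = identity.copy()
    for _ in range(num_iters):
        T = 0.5 * (3.0 * identity - Z @ Y)
        Y = Y @ T
        Z = T @ Z
    return Y * np.sqrt(norm)   # undo the scaling
```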

License

This implementation is licensed under the Apache License 2.0.

This implementation is derived from pytorch-fid, licensed under the Apache License 2.0.

FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see https://arxiv.org/abs/1706.08500

The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. See https://github.com/bioinf-jku/TTUR.


Download files


Source Distribution

pytorch_gan_metrics-0.3.2.tar.gz (18.0 kB)

Uploaded Source

Built Distribution

pytorch_gan_metrics-0.3.2-py3-none-any.whl (17.4 kB)

Uploaded Python 3

File details

Details for the file pytorch_gan_metrics-0.3.2.tar.gz.

File metadata

  • Download URL: pytorch_gan_metrics-0.3.2.tar.gz
  • Upload date:
  • Size: 18.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.1 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.2 CPython/3.8.9

File hashes

Hashes for pytorch_gan_metrics-0.3.2.tar.gz
Algorithm Hash digest
SHA256 0c6811e0f3c5368d46b1bb301d38d962a1f6c9f88867b6ef41121c12ff257947
MD5 86844ae67b33e7d9dca5f1a43995fa20
BLAKE2b-256 e94b751c68db45313b3a3fa4079a6594f98d2e169d2cc15c52d33877a366ba6b


File details

Details for the file pytorch_gan_metrics-0.3.2-py3-none-any.whl.

File metadata

  • Download URL: pytorch_gan_metrics-0.3.2-py3-none-any.whl
  • Upload date:
  • Size: 17.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.1 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.2 CPython/3.8.9

File hashes

Hashes for pytorch_gan_metrics-0.3.2-py3-none-any.whl
Algorithm Hash digest
SHA256 cf91880fc07605d8fe35ad1a8cdeb594fa5c00889c75e4afd7c9bbafbf78256e
MD5 89ab29617719dc4767269e71f97c9041
BLAKE2b-256 25e7447ab5f7cd4855b34523cf52608e9468e7c4634638e2d71097e0147bbf1a

