
Package for calculating GAN metrics using PyTorch


PyTorch Implementation of Common GAN Metrics

Notes

The FID implementation is inspired by pytorch-fid.

Features

  • Currently, this package supports the following metrics: Inception Score (IS) and Fréchet Inception Distance (FID).
  • The computations of IS and FID are integrated to avoid multiple forward passes through the Inception network.
  • Images are read on the fly for both metrics.
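The idea behind the integrated computation can be sketched as a single loop that keeps two outputs per batch: pooled features for the FID statistics and class probabilities for IS. This is a minimal sketch, not the package's code; `collect_outputs` and the two-output `model` are hypothetical stand-ins for an Inception network that exposes both its pooling features and its logits.

```python
import numpy as np
import torch


def collect_outputs(loader, model, device="cpu"):
    """Run each batch through `model` once and keep both of its outputs.

    `model` is a hypothetical network returning (pool_features, logits):
    the features feed the FID statistics, softmax(logits) feeds IS.
    """
    feats, probs = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:  # batch: [B, 3, H, W] floats in [0, 1]
            features, logits = model(batch.to(device))
            feats.append(features.cpu().numpy())
            probs.append(torch.softmax(logits, dim=1).cpu().numpy())
    return np.concatenate(feats), np.concatenate(probs)


class TwoHeadDummy(torch.nn.Module):
    """Toy stand-in for Inception: 'features' and 'logits' from flattened pixels."""

    def forward(self, x):
        flat = x.flatten(1)
        return flat[:, :8], flat[:, :10]


# Any iterable of image batches works here, e.g. a DataLoader or a plain list
batches = [torch.rand(4, 3, 8, 8) for _ in range(3)]
features, probs = collect_outputs(batches, TwoHeadDummy())
```

Each image passes through the network exactly once, and both metrics can then be computed from the two arrays.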

Reproducing Results of Official Implementations

  • CIFAR-10

    Implementation         Train IS       Test IS        FID (Train 50k vs Test 10k)
    Official               11.24±0.20     10.98±0.22     3.1508
    pytorch-gan-metrics    11.26±0.27     10.97±0.33     3.1517

    Due to differences between the PyTorch and TensorFlow frameworks, the results differ slightly from those of the official implementations.
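For reference, IS is exp(E_x[KL(p(y|x) || p(y))]) computed on the Inception softmax outputs, and the ± values in such tables are presumably the standard deviation of the score over splits of the sample set. The sketch below assumes the 10-split convention of the original TensorFlow implementation; `inception_score` is a hypothetical helper operating on a precomputed [N, num_classes] probability matrix, not this package's API.

```python
import numpy as np


def inception_score(probs, splits=10, eps=1e-12):
    """Hypothetical helper: IS mean and std from an [N, num_classes] matrix of p(y|x).

    IS = exp(mean_x KL(p(y|x) || p(y))), evaluated separately on each split
    (splits=10 is an assumed default).
    """
    scores = []
    for part in np.array_split(probs, splits):
        # Marginal p(y) estimated from the images in this split
        py = part.mean(axis=0, keepdims=True)
        # KL divergence of each image's p(y|x) from the marginal
        kl = np.sum(part * (np.log(part + eps) - np.log(py + eps)), axis=1)
        scores.append(np.exp(kl.mean()))
    return float(np.mean(scores)), float(np.std(scores))


# Confident predictions spread evenly over classes give the maximum score (the
# number of classes); identical predictions for every image give the minimum of 1.
sharp = np.tile(np.eye(10), (10, 1))  # 100 one-hot rows, each class equally often
flat = np.full((100, 10), 0.1)        # every image predicts the uniform distribution
```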

Install

python setup.py install

Prepare Statistics for FID

  • Download precalculated statistics for a dataset, or
  • Calculate statistics for your custom dataset using the command-line tool:
    python -m pytorch_gan_metrics.calc_fid_stats --path path/to/images --output name.npz
    
    See calc_fid_stats.py for implementation details.
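A statistics file holds the mean and covariance of Inception pooling features over a reference set, and FID compares two such Gaussians via ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 sqrt(sigma1 sigma2)). The sketch below assumes the `.npz` keys are `mu` and `sigma`, as in pytorch-fid; check calc_fid_stats.py for the names actually used. The helper names are hypothetical, and random vectors stand in for real Inception features.

```python
import os
import tempfile

import numpy as np
from scipy import linalg


def save_fid_stats(features, path):
    """Store mean and covariance of an [N, D] feature matrix (key names assumed)."""
    mu = features.mean(axis=0)
    sigma = np.cov(features, rowvar=False)
    np.savez(path, mu=mu, sigma=sigma)


def load_fid_stats(path):
    with np.load(path) as f:
        return f["mu"], f["sigma"]


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))."""
    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    # sqrtm may return negligible imaginary parts caused by numerical error
    covmean = np.real(covmean)
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * np.trace(covmean))


# Random features stand in for Inception pooling features of two image sets
rng = np.random.default_rng(0)
real = rng.normal(size=(500, 4))
fake = rng.normal(loc=1.0, size=(500, 4))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "real.npz")
    save_fid_stats(real, path)
    mu_r, sigma_r = load_fid_stats(path)

fid = frechet_distance(mu_r, sigma_r, fake.mean(axis=0), np.cov(fake, rowvar=False))
```

Precomputing the reference statistics once means only the generated images need to pass through the network at evaluation time.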

Documentation

Using torch.Tensor as images

  • Prepare images as a torch.float32 tensor of shape [N, 3, H, W] with values normalized to [0, 1].
    from pytorch_gan_metrics import (get_inception_score,
                                     get_fid,
                                     get_inception_score_and_fid)
    images = ... # [N, 3, H, W]
    assert 0 <= images.min() and images.max() <= 1
    # Inception Score
    IS, IS_std = get_inception_score(images)
    # Frechet Inception Distance
    FID = get_fid(images, 'path/to/statistics.npz')
    # Inception Score + Frechet Inception Distance
    (IS, IS_std), FID = get_inception_score_and_fid(
        images, 'path/to/statistics.npz')
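Generators commonly output values in [-1, 1] (for example through a tanh output layer), and images loaded from disk are often uint8 in [0, 255]. A minimal sketch of converting either into the expected float32 [0, 1] range follows; `to_unit_range` is a hypothetical helper, and its default `value_range` assumes a [-1, 1] generator.

```python
import torch


def to_unit_range(images, value_range=(-1.0, 1.0)):
    """Hypothetical helper: map images to float32 in [0, 1], keeping the shape.

    uint8 input is treated as [0, 255]; float input is linearly rescaled
    from `value_range` (assumed to be the generator's output range).
    """
    if images.dtype == torch.uint8:
        return images.float() / 255.0
    low, high = value_range
    images = (images.float() - low) / (high - low)
    # Clamp to guard against small overshoots outside the nominal range
    return images.clamp(0.0, 1.0)


fake = torch.tanh(torch.randn(8, 3, 32, 32))  # stand-in for generator output
images = to_unit_range(fake)
assert 0 <= images.min() and images.max() <= 1
```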
    

Using PyTorch DataLoader

  • Use pytorch_gan_metrics.ImageDataset to collect images on disk, or use a custom dataset whose __getitem__ returns only an image.
    from torch.utils.data import DataLoader
    from pytorch_gan_metrics import ImageDataset
    
    dataset = ImageDataset(path_to_dir, exts=['png', 'jpg'])
    loader = DataLoader(dataset, batch_size=50, num_workers=4)
    
  • It is possible to wrap a generative model in a dataset to generate images on the fly. Remember to set num_workers=0 to avoid copying the model into multiple worker processes.
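    import torch
    from torch.utils.data import DataLoader, Dataset

    class GeneratorDataset(Dataset):
        def __init__(self, G, z_dim):
            self.G = G
            self.z_dim = z_dim

        def __len__(self):
            # Number of images to generate for evaluation
            return 50000

        def __getitem__(self, index):
            # Generate a single image; G should output values in [0, 1].
            # no_grad avoids building an autograd graph during evaluation.
            with torch.no_grad():
                return self.G(torch.randn(1, self.z_dim).cuda())[0]

    dataset = GeneratorDataset(G, z_dim=128)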
    loader = DataLoader(dataset, batch_size=50, num_workers=0)
    
  • Calculate metrics
    from pytorch_gan_metrics import (get_inception_score,
                                     get_fid,
                                     get_inception_score_and_fid)
    # Inception Score
    IS, IS_std = get_inception_score(loader)
    # Frechet Inception Distance
    FID = get_fid(loader, 'path/to/statistics.npz')
    # Inception Score + Frechet Inception Distance
    (IS, IS_std), FID = get_inception_score_and_fid(
        loader, 'path/to/statistics.npz')
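A custom dataset passed this way should return only an image from __getitem__, but many existing datasets (for example the torchvision datasets) return (image, label) pairs. A thin wrapper that drops everything but the image is one way to reuse them. `ImagesOnly` is a hypothetical name, and the wrapped dataset is assumed to already yield [3, H, W] float tensors in [0, 1] (e.g. via torchvision's transforms.ToTensor()).

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class ImagesOnly(Dataset):
    """Hypothetical wrapper: __getitem__ returns only the image of (image, label, ...)."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        item = self.dataset[index]
        # Keep the first element (the image) and discard labels or other metadata
        return item[0] if isinstance(item, (tuple, list)) else item


# Usage with any (image, label) dataset; a TensorDataset stands in for real data here
images = torch.rand(10, 3, 32, 32)
labels = torch.zeros(10, dtype=torch.long)
loader = DataLoader(ImagesOnly(TensorDataset(images, labels)), batch_size=5, num_workers=0)
```

The resulting loader yields plain [B, 3, H, W] tensors, the same shape a loader over ImageDataset produces.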
    

From directory

  • Calculate metrics for images in the directory.
    from pytorch_gan_metrics import (
        get_inception_score_from_directory,
        get_fid_from_directory,
        get_inception_score_and_fid_from_directory)
    
    fid_stats_path = 'path/to/statistics.npz'
    IS, IS_std = get_inception_score_from_directory('path/to/images')
    FID = get_fid_from_directory('path/to/images', fid_stats_path)
    (IS, IS_std), FID = get_inception_score_and_fid_from_directory(
        'path/to/images', fid_stats_path)
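Collecting the images in a directory amounts to filtering files by extension before loading them. A minimal sketch with pathlib follows; `list_images` is a hypothetical helper that mimics the `exts` argument of ImageDataset, and since the text does not say whether subdirectories are searched, the `recursive` flag is an assumption.

```python
import tempfile
from pathlib import Path


def list_images(root, exts=("png", "jpg"), recursive=False):
    """Hypothetical helper: sorted files under `root` whose extension is in `exts`.

    Extensions are matched case-insensitively; `recursive` is an assumed option.
    """
    wanted = {"." + e.lower().lstrip(".") for e in exts}
    pattern = "**/*" if recursive else "*"
    return sorted(
        p for p in Path(root).glob(pattern)
        if p.is_file() and p.suffix.lower() in wanted)


with tempfile.TemporaryDirectory() as tmp:
    for name in ["a.png", "b.JPG", "notes.txt"]:
        (Path(tmp) / name).touch()
    (Path(tmp) / "sub").mkdir()
    (Path(tmp) / "sub" / "c.png").touch()
    top_level = [p.name for p in list_images(tmp)]
    everything = [p.name for p in list_images(tmp, recursive=True)]
```

Sorting the paths keeps the image order deterministic across runs and file systems.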
    

License

This implementation is licensed under the Apache License 2.0.

This implementation is derived from pytorch-fid, licensed under the Apache License 2.0.

FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see https://arxiv.org/abs/1706.08500

The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. See https://github.com/bioinf-jku/TTUR.

Download files

Source Distribution

pytorch_gan_metrics-0.1.0.tar.gz (17.2 kB view details)

Uploaded Source

Built Distributions

pytorch_gan_metrics-0.1.0-py3.6.egg (25.3 kB view details)

Uploaded Source

pytorch_gan_metrics-0.1.0-py3-none-any.whl (16.9 kB view details)

Uploaded Python 3

