
Package for calculating GAN metrics using PyTorch


This project has been renamed to pytorch-image-generation-metrics. Please visit the new repository here.

PyTorch Implementation of Common GAN Metrics


Install

pip install pytorch-gan-metrics

Requirements:
  • torch>=1.8.2
  • torchvision>=0.9.2

Quick Start

from pytorch_gan_metrics import get_inception_score, get_fid

images = ... # [N, 3, H, W] normalized to [0, 1]
IS, IS_std = get_inception_score(images)        # Inception Score
FID = get_fid(images, 'path/to/statistics.npz') # Frechet Inception Distance

path/to/statistics.npz is compatible with the official FID implementation.

Notes

The FID implementation is inspired by pytorch-fid.

This repository is developed for personal research. If you find this package useful, please feel free to open issues.

Features

  • Currently, this package supports the following metrics: Inception Score (IS) and Fréchet Inception Distance (FID).
  • The IS and FID computations are integrated to avoid multiple forward passes.
  • Images can be read on the fly to avoid running out of memory, especially with large-scale images.
  • Computation on GPU is supported to speed up operations that otherwise run on the CPU, such as np.cov and scipy.linalg.sqrtm.
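For orientation, the two metrics can be sketched in plain NumPy. This is a minimal illustration of the standard definitions, not this package's code: the function names `inception_score` and `frechet_distance` are hypothetical, the inputs are assumed to be softmax outputs and feature statistics that have already been extracted, and the trace of the matrix square root is computed from eigenvalues instead of scipy.linalg.sqrtm to keep the sketch NumPy-only.

```python
import numpy as np


def inception_score(probs, eps=1e-12):
    """exp(mean_x KL(p(y|x) || p(y))) for an [N, num_classes] array of softmax outputs."""
    probs = np.asarray(probs, dtype=np.float64)
    marginal = probs.mean(axis=0, keepdims=True)  # p(y), averaged over all images
    kl = (probs * (np.log(probs + eps) - np.log(marginal + eps))).sum(axis=1)
    return float(np.exp(kl.mean()))


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))."""
    mu1 = np.asarray(mu1, dtype=np.float64)
    mu2 = np.asarray(mu2, dtype=np.float64)
    sigma1 = np.asarray(sigma1, dtype=np.float64)
    sigma2 = np.asarray(sigma2, dtype=np.float64)
    diff = mu1 - mu2
    # The trace of the matrix square root equals the sum of the square roots
    # of the eigenvalues, so no explicit matrix square root is needed here.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)


# Toy inputs: four perfectly confident, perfectly diverse predictions,
# and two Gaussians that differ only in their means.
toy_is = inception_score(np.eye(4))
sigma = np.array([[2.0, 0.5, 0.0], [0.5, 1.0, 0.0], [0.0, 0.0, 3.0]])
toy_fid = frechet_distance(np.zeros(3), sigma, np.array([1.0, 2.0, 2.0]), sigma)
```

In the toy case the covariance terms cancel, so the distance reduces to the squared distance between the means.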

Reproducing Results of Official Implementations on CIFAR-10

                                                      Train IS      Test IS       Train(50k) vs Test(10k) FID
Official                                              11.24±0.20    10.98±0.22    3.1508
pytorch-gan-metrics                                   11.26±0.13    10.96±0.35    3.1518
pytorch-gan-metrics (use_torch=True, torch==2.0.1)    11.26±0.15    10.95±0.16    3.1309

The results are slightly different from official implementations due to the framework difference between PyTorch and TensorFlow.

Documentation

Prepare Statistics (for FID)

  • Download precalculated statistics, or
  • Calculate statistics for your custom dataset with the command-line tool:
    python -m pytorch_gan_metrics.calc_fid_stats \
        --path path/to/images \
        --stats path/to/statistics.npz
    
    See calc_fid_stats.py for details.
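To make the statistics file less opaque, here is a small sketch of what such a file typically holds: the mean and covariance of the Inception features. The key names `mu` and `sigma` are an assumption based on the official FID statistics files, and the random 8-dimensional features are only a stand-in for real InceptionV3 pooling features (which are 2048-dimensional).

```python
import os
import tempfile

import numpy as np

# Stand-in features: 100 samples with 8 dimensions each (hypothetical data).
rng = np.random.default_rng(0)
features = rng.normal(size=(100, 8))

mu = features.mean(axis=0)              # feature mean, shape [D]
sigma = np.cov(features, rowvar=False)  # feature covariance, shape [D, D]

# Assumed key names ('mu', 'sigma'), following the official FID statistics files.
path = os.path.join(tempfile.mkdtemp(), 'statistics.npz')
np.savez(path, mu=mu, sigma=sigma)

# Reading the file back gives the two arrays needed to compute FID.
with np.load(path) as stats:
    loaded_mu, loaded_sigma = stats['mu'], stats['sigma']
print(loaded_mu.shape, loaded_sigma.shape)
```

For real datasets, the command-line tool above remains the intended way to produce the file, since it also runs the images through InceptionV3.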

Inception Features

  • When computing IS or FID, InceptionV3 is loaded onto torch.device('cuda:0') if a GPU is available; otherwise, torch.device('cpu') is used.
  • Set the device argument of the get_* functions to choose a different torch device.
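A minimal sketch of choosing the device explicitly, mirroring the default described above. Passing it as the keyword `device=` is an assumption about the exact call form; the package call is left commented out because it needs real images and a statistics file.

```python
import torch

# Same rule as the default: first GPU if one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Assumed keyword form of the device argument:
# FID = get_fid(images, 'path/to/statistics.npz', device=device)
print(device)
```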

Using torch.Tensor as images

  • Prepare images as a torch.float32 tensor of shape [N, 3, H, W], normalized to [0, 1].
    from pytorch_gan_metrics import (get_inception_score,
                                     get_fid,
                                     get_inception_score_and_fid)
    images = ... # [N, 3, H, W]
    assert 0 <= images.min() and images.max() <= 1
    # Inception Score
    IS, IS_std = get_inception_score(
        images)
    # Frechet Inception Distance
    FID = get_fid(
        images, 'path/to/statistics.npz')
    # Inception Score & Frechet Inception Distance
    (IS, IS_std), FID = get_inception_score_and_fid(
        images, 'path/to/statistics.npz')
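Images often start out as uint8 arrays in [N, H, W, 3] layout, so a conversion step is usually needed first. The helper below is a hypothetical sketch (`to_unit_range_nchw` is not part of the package) that produces the float32 [N, 3, H, W] layout in [0, 1] using NumPy only.

```python
import numpy as np


def to_unit_range_nchw(images_uint8):
    """Convert [N, H, W, 3] uint8 images in [0, 255] to [N, 3, H, W] float32 in [0, 1]."""
    images = np.asarray(images_uint8)
    if images.ndim != 4 or images.shape[-1] != 3:
        raise ValueError('expected an array of shape [N, H, W, 3]')
    images = images.astype(np.float32) / 255.0
    # Move the channel axis to position 1 and make the memory layout contiguous.
    return np.ascontiguousarray(images.transpose(0, 3, 1, 2))


# Hypothetical batch of four random 32x32 RGB images.
batch = np.random.default_rng(0).integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
images = to_unit_range_nchw(batch)
```

Wrapping the result with torch.from_numpy(images) then gives a torch.float32 tensor in the layout the functions above expect.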
    

Using PyTorch DataLoader to Provide Images

  • Use pytorch_gan_metrics.ImageDataset to collect images on your storage or use your custom torch.utils.data.Dataset.

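    from torch.utils.data import DataLoader
    from pytorch_gan_metrics import ImageDataset
    
    path_to_dir = 'path/to/images'  # directory that holds the images
    dataset = ImageDataset(path_to_dir, exts=['png', 'jpg'])
    loader = DataLoader(dataset, batch_size=50, num_workers=4)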
    
  • It is possible to wrap a generative model in a dataset to generate images on the fly. Remember to set num_workers=0 to avoid copying the model across worker processes.

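    import torch
    from torch.utils.data import Dataset, DataLoader
    
    class GeneratorDataset(Dataset):
        def __init__(self, G, z_dim):
            self.G = G          # generator network, assumed to be on the GPU
            self.z_dim = z_dim  # dimension of the latent vector
    
        def __len__(self):
            return 50000
    
        def __getitem__(self, index):
            # Generate one image per item; gradients are not needed for evaluation.
            with torch.no_grad():
                return self.G(torch.randn(1, self.z_dim).cuda())[0]
    
    dataset = GeneratorDataset(G, z_dim=128)
    loader = DataLoader(dataset, batch_size=50, num_workers=0)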
    
  • Calculate metrics

    from pytorch_gan_metrics import (get_inception_score,
                                     get_fid,
                                     get_inception_score_and_fid)
    # Inception Score
    IS, IS_std = get_inception_score(
        loader)
    # Frechet Inception Distance
    FID = get_fid(
        loader, 'path/to/statistics.npz')
    # Inception Score & Frechet Inception Distance
    (IS, IS_std), FID = get_inception_score_and_fid(
        loader, 'path/to/statistics.npz')
    

Load Images from a Directory

  • Calculate metrics for images in a directory and its subfolders.
    from pytorch_gan_metrics import (
        get_inception_score_from_directory,
        get_fid_from_directory,
        get_inception_score_and_fid_from_directory)
    
    IS, IS_std = get_inception_score_from_directory(
        'path/to/images')
    FID = get_fid_from_directory(
        'path/to/images', 'path/to/statistics.npz')
    (IS, IS_std), FID = get_inception_score_and_fid_from_directory(
        'path/to/images', 'path/to/statistics.npz')
    

Accelerating Matrix Computation by PyTorch

  • Set use_torch=True when calling any get_* function, such as get_inception_score or get_fid.

  • WARNING: with use_torch=True, the FID might be nan because the matrix square root implementation is numerically unstable.

  • This option is recommended when evaluating generative models on a server with fast GPUs but a low CPU frequency.
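One way to live with the nan warning is to try the PyTorch path first and recompute on the default path only when the result is nan. The wrapper below is a hypothetical sketch, not part of the package: `fid_with_fallback` and its `compute_fid` callable are illustrative names, and it assumes the FID is returned as a plain number.

```python
import math


def fid_with_fallback(compute_fid):
    """Try the PyTorch matrix path first; fall back to the default path if FID is NaN.

    `compute_fid` is a hypothetical callable that accepts `use_torch` and returns the FID.
    """
    fid = float(compute_fid(use_torch=True))
    if math.isnan(fid):
        # Recompute without the PyTorch matrix routines.
        fid = float(compute_fid(use_torch=False))
    return fid


# Intended use with the package (commented out: needs images and a statistics file):
# FID = fid_with_fallback(
#     lambda use_torch: get_fid(images, 'path/to/statistics.npz', use_torch=use_torch))
```

The fallback costs a second full computation, so it only makes sense when nan results are rare.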

Tested Versions

  • python 3.9 + torch 1.8.2 + CUDA 10.2
  • python 3.9 + torch 1.11.0 + CUDA 10.2
  • python 3.9 + torch 1.12.1 + CUDA 10.2

License

This implementation is licensed under the Apache License 2.0.

This implementation is derived from pytorch-fid, licensed under the Apache License 2.0.

FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see https://arxiv.org/abs/1706.08500

The original implementation of FID is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. See https://github.com/bioinf-jku/TTUR.
