GAN Evaluator for IS and FID
GAN Evaluator for Inception Score (IS) and Frechet Inception Distance (FID) in PyTorch
Contributors: Chen Liu (chen.liu.cl2482@yale.edu), Alex Wong (alex.wong@yale.edu)
If you find this repo useful, please kindly star it to help it reach more people. Let's help out the community!
Main Contributions
- We created a GAN evaluator for IS and FID that
  - is easy to use,
  - accepts data as either dataloaders or individual batches, and
  - supports on-the-fly evaluation during training.
- We provided a simple demo script to demonstrate one common use case.
NEWS
[Feb 18, 2023]
Now available on PyPI! You can now pip install it into your desired environment via:
pip install gan-evaluator
Then, in your Python project, wherever you need GAN_Evaluator, you can import it via:
from gan_evaluator import GAN_Evaluator
NOTE 1: You no longer need to copy any code from this repo in order to use GAN_Evaluator! At this point, the primary purpose of this repo is description and demonstration. That said, you can certainly clone this repo and try out the demo script. You may also find it easier to copy and modify the code if you want slightly different behaviors.
NOTE 2: Running pip install gan-evaluator also installs the dependencies of GAN_Evaluator (but not those of the demo script).
Demo Script: Use DCGAN to generate SVHN digits
- The script can be found in src/train_dcgan_svhn.py.
- Usage from the demo script, to give you a taste:
Declaration
evaluator = GAN_Evaluator(device=device, num_images_real=len(train_loader.dataset), num_images_fake=len(train_loader.dataset))
Before the training loop
evaluator.load_all_real_imgs(real_loader=train_loader, idx_in_loader=0)
Inside the training loop
if shall_plot:
    IS_mean, IS_std, FID = evaluator.fill_fake_img_batch(fake_batch=x_fake)
else:
    evaluator.fill_fake_img_batch(fake_batch=x_fake, return_results=False)
After each epoch of training
evaluator.clear_fake_imgs()
- Some visualizations of the demo script:
- Real (top) and Generated (bottom) images.
- IS and FID curves.
Details: The Evaluator for IS and FID
Introduction to the Evaluator
More details can be found in src/utils/gan_evaluator.py (class GAN_Evaluator).
This evaluator computes the following metrics:
- Inception Score (IS)
- Frechet Inception Distance (FID)
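For reference, the two metrics are usually defined as below, where p(y | x) is the Inception v3 class distribution for a generated image x, p(y) is its marginal over the generated images, and (μ_r, Σ_r), (μ_f, Σ_f) are the mean and covariance of Inception activations for the real and fake images. This is the standard textbook form, given as a reading aid; the evaluator's exact choices (for example, which Inception layer supplies the activations) are not spelled out here.

```latex
\mathrm{IS} = \exp\!\Big(\mathbb{E}_{x \sim p_g}\big[D_{\mathrm{KL}}\big(p(y \mid x)\,\|\,p(y)\big)\big]\Big),
\qquad p(y) = \mathbb{E}_{x \sim p_g}\big[p(y \mid x)\big]

\mathrm{FID} = \lVert \mu_r - \mu_f \rVert_2^2
  + \operatorname{Tr}\!\Big(\Sigma_r + \Sigma_f - 2\,(\Sigma_r \Sigma_f)^{1/2}\Big)
```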
This evaluator takes in the real images and the fake (generated) images.
It then computes the activations from both the real and fake images, as well as the
predictions from the fake images.
The (fake) predictions are used to compute IS, while
the (real, fake) activations are used to compute FID.
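A minimal NumPy/SciPy sketch of how predictions and activations are commonly turned into these numbers. It assumes predictions arrive as an (N, num_classes) array of softmax probabilities and activations as (N, D) arrays. `inception_score` and `frechet_distance` are illustrative names, not part of GAN_Evaluator's API, and the default of 10 splits for IS is an assumption borrowed from common practice.

```python
import numpy as np
from scipy import linalg


def inception_score(probs: np.ndarray, num_splits: int = 10, eps: float = 1e-12):
    """IS from softmax predictions of shape (N, num_classes); returns (mean, std).

    Hypothetical helper: the N predictions are split into `num_splits` chunks,
    exp(E_x[KL(p(y|x) || p(y))]) is computed per chunk, and the spread is reported.
    """
    scores = []
    for chunk in np.array_split(probs, num_splits):
        p_y = chunk.mean(axis=0, keepdims=True)  # marginal p(y) within this chunk
        kl = (chunk * (np.log(chunk + eps) - np.log(p_y + eps))).sum(axis=1)
        scores.append(np.exp(kl.mean()))
    return float(np.mean(scores)), float(np.std(scores))


def frechet_distance(act_real: np.ndarray, act_fake: np.ndarray) -> float:
    """FID between two activation matrices of shape (N, D). Hypothetical helper."""
    mu_r, mu_f = act_real.mean(axis=0), act_fake.mean(axis=0)
    sigma_r = np.atleast_2d(np.cov(act_real, rowvar=False))
    sigma_f = np.atleast_2d(np.cov(act_fake, rowvar=False))
    # Matrix square root of the covariance product; numerical error can
    # introduce tiny imaginary parts, which are discarded.
    covmean = np.real(linalg.sqrtm(sigma_r @ sigma_f))
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(sigma_r) + np.trace(sigma_f) - 2.0 * np.trace(covmean))
```

Splitting the predictions into chunks is what makes an "IS std" meaningful: the score is computed once per chunk and the standard deviation across chunks is reported alongside the mean.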
If input image resolution < 75 x 75, we will upsample the image to accommodate Inception v3.
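A sketch of the resizing step in PyTorch, assuming images arrive as an (N, C, H, W) float tensor. The 75-pixel threshold comes from the text above; resizing to 299 x 299 (Inception v3's standard input size) with bilinear interpolation is an assumption, and `maybe_upsample` is a hypothetical name. The evaluator's actual target size and interpolation mode may differ.

```python
import torch
import torch.nn.functional as F


def maybe_upsample(images: torch.Tensor, min_size: int = 75, target_size: int = 299) -> torch.Tensor:
    """Resize (N, C, H, W) images to target_size x target_size if either side is below min_size.

    Hypothetical helper: images that are already large enough are returned unchanged.
    """
    _, _, h, w = images.shape
    if h >= min_size and w >= min_size:
        return images
    return F.interpolate(images, size=(target_size, target_size), mode="bilinear", align_corners=False)
```

For example, a batch of 32 x 32 SVHN digits would be resized to 299 x 299, while 128 x 128 images would pass through as-is.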
The real and fake images can be provided to this evaluator in either of the following formats:
1. dataloader: `load_all_real_imgs`, `load_all_fake_imgs`
2. per-batch: `fill_real_img_batch`, `fill_fake_img_batch`
!!! Please note: the latest IS and FID are returned upon completion of either `load_all_fake_imgs` or `fill_fake_img_batch`.
Return format: (IS mean, IS std, FID)
*So please make sure you load the real images before the fake images.*
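A hypothetical sketch of how a single store can serve both input formats, with the dataloader path simply looping over the per-batch path, plus a guard enforcing the "real before fake" rule. `FeatureBuffer` and `score_if_ready` are illustrative names, not the library's internals. Preallocating `capacity` slots loosely mirrors the `num_images_real` / `num_images_fake` arguments in the declaration shown earlier, dropping overflow is an assumption, and the distance-between-means score is only a placeholder standing in for IS/FID.

```python
from typing import Iterable

import numpy as np


class FeatureBuffer:
    """Hypothetical fixed-capacity store for per-image feature vectors."""

    def __init__(self, capacity: int, dim: int):
        self.data = np.zeros((capacity, dim), dtype=np.float64)
        self.count = 0

    def fill_batch(self, batch: np.ndarray) -> None:
        # Per-batch format: append one batch. Assumption: rows beyond capacity are dropped.
        n = min(len(batch), len(self.data) - self.count)
        self.data[self.count:self.count + n] = batch[:n]
        self.count += n

    def load_all(self, loader: Iterable[np.ndarray]) -> None:
        # Dataloader format: the per-batch path applied to every batch.
        for batch in loader:
            self.fill_batch(batch)

    def clear(self) -> None:
        # Reset the write position, e.g. at the end of each training epoch.
        self.count = 0

    def filled(self) -> np.ndarray:
        return self.data[:self.count]


def score_if_ready(real: FeatureBuffer, fake: FeatureBuffer) -> float:
    # Enforce the ordering rule: real features must be present before fake ones are scored.
    if real.count == 0:
        raise RuntimeError("load the real images before filling fake images")
    if fake.count == 0:
        raise RuntimeError("no fake images to score yet")
    # Placeholder metric (assumption): distance between feature means, standing in for IS/FID.
    return float(np.linalg.norm(real.filled().mean(axis=0) - fake.filled().mean(axis=0)))
```

Routing the dataloader format through the per-batch method keeps a single code path for storage, so both formats end up with identical buffer contents.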
Common Use Cases:
1. For the purpose of on-the-fly evaluation during GAN training:
We recommend pre-loading the real images using the dataloader format, and
populating the fake images using the per-batch format as training goes on.
- At the end of each epoch, you can clear the fake images using `clear_fake_imgs`.
- In *unusual* cases where your real images change (such as in progressive growing GANs),
you may want to clear the real images. You can do so via `clear_real_imgs`.
2. For the purpose of offline evaluation of a saved dataset:
We recommend pre-loading both the real images and the fake images.
Repository Hierarchy
GAN-evaluator
├── config
|   └── `dcgan_svhn.yaml`
├── data (*)
├── debug_plot (*)
├── logs (*)
└── src
    ├── utils
    |   ├── `gan_evaluator.py`: THIS CONTAINS OUR `GAN_Evaluator`.
    |   └── other utility files.
    └── `train_dcgan_svhn.py`: our demo script.
Folders marked with (*) will be created automatically, if they do not already exist, when you run train_dcgan_svhn.py.
Usage
- To run our demo script, do the following after activating the proper environment.
git clone git@github.com:ChenLiu-1996/GAN-evaluator.git
cd GAN-evaluator/src
python train_dcgan_svhn.py --config ../config/dcgan_svhn.yaml
- To integrate our evaluator into your existing project, you can simply copy src/utils/gan_evaluator.py to an appropriate folder in your project and import GAN_Evaluator wherever you need it. Update: you can now install it directly via pip!
- We will add our citation BibTeX, and we would appreciate it if you cited our work in case this repository helps your research.
Citation
To be added.
Environment Setup
Packages Needed
The GAN_Evaluator module itself only uses numpy, scipy, torch, torchvision, and (for aesthetics) tqdm.
The demo script additionally requires matplotlib, argparse, and yaml.
On our Yale Vision Lab server
- There is a virtualenv ready to use, located at /media/home/chliu/.virtualenv/mondi-image-gen/.
- Alternatively, you can start from an existing environment "torch191-py38env" and install the following packages:
python3 -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
python3 -m pip install wget gdown numpy matplotlib pyyaml click scipy yacs scikit-learn
If you see error messages such as "Failed to build CUDA kernels for bias_act.", you can fix them with:
python3 -m pip install ninja
Acknowledgements
- The code for the GAN_Evaluator (specifically, the computation of IS and FID) is inspired by:
  - https://github.com/sbarratt/inception-score-pytorch/blob/master/inception_score.py
  - https://www.kaggle.com/code/ibtesama/gan-in-pytorch-with-fid
  - https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
  - Note: We did not validate the "mathematical correctness" of their computations. Please use mindfully.
- The code for the demo script (specifically, architecture and training of DCGAN) is inspired by:
File details
Details for the file gan-evaluator-1.15.tar.gz.
File metadata
- Download URL: gan-evaluator-1.15.tar.gz
- Upload date:
- Size: 8.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.8.10
File hashes
Algorithm | Hash digest
---|---
SHA256 | 07609c2965a98baa275c26056532bab6efbd1a4c69d0f31efd4ac8059403463f
MD5 | f1dafaff2b2888e4aa4e03cab004e264
BLAKE2b-256 | 86dfa7e83c42a31c243fab4ef99a66b20b0292239a2ae1799692743ab09e58cb
File details
Details for the file gan_evaluator-1.15-py3-none-any.whl.
File metadata
- Download URL: gan_evaluator-1.15-py3-none-any.whl
- Upload date:
- Size: 10.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.8.10
File hashes
Algorithm | Hash digest
---|---
SHA256 | aca501d8378756d8ab9b2c29f9b774a727072d034757adb9708f8b13aeae49d1
MD5 | ef8041618e5c5700a45b5255824d8f7c
BLAKE2b-256 | ee15b3ee8296d1d1bb9d2689f43d272b76b2e3b2c23d466810aedebda1ec74dd