Vision-aided GAN training
Project description
Vision-aided GAN
video | website | paper
Can the collective knowledge from a large bank of pretrained vision models be leveraged to improve GAN training? If so, with so many models to choose from, which one(s) should be selected, and in what manner are they most effective?
We find that pretrained computer vision models can significantly improve performance when used in an ensemble of discriminators. We propose an effective selection mechanism, by probing the linear separability between real and fake samples in pretrained model embeddings, choosing the most accurate model, and progressively adding it to the discriminator ensemble. Our method can improve GAN training in both limited data and large-scale settings.
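As a rough illustration of that selection mechanism (a minimal sketch, not the paper's actual code; `get_embeddings`, `real_images`, `fake_images`, and `candidate_models` are placeholders), one can rank candidate backbones by the held-out accuracy of a linear probe that separates real from generated samples in each model's embedding space:

```python
# Hypothetical sketch: rank candidate pretrained backbones by how linearly
# separable real vs. generated samples are in their embedding space.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def linear_probe_accuracy(real_feats: np.ndarray, fake_feats: np.ndarray) -> float:
    X = np.concatenate([real_feats, fake_feats])
    y = np.concatenate([np.ones(len(real_feats)), np.zeros(len(fake_feats))])
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.5, random_state=0, stratify=y)
    clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
    return clf.score(X_te, y_te)  # higher accuracy -> embeddings separate real/fake better

# scores = {name: linear_probe_accuracy(get_embeddings(name, real_images),
#                                        get_embeddings(name, fake_images))
#           for name in candidate_models}
# next_model = max(scores, key=scores.get)  # add the most accurate model to the ensemble
```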
Ensembling Off-the-shelf Models for GAN Training
Nupur Kumari, Richard Zhang, Eli Shechtman, Jun-Yan Zhu
In CVPR 2022
Pretrained Models
Our final trained models can be downloaded at this link. For more details on usage, please see the READMEs in the stylegan2 and biggan folders.
Vision-aided StyleGAN2 training
Please see the stylegan2 README for training StyleGAN2 models with our method. This code will reproduce all StyleGAN2-based results from our paper.
Vision-aided Discriminator in a custom GAN model
```bash
pip install vision-aided-loss
```

Or install from GitHub:

```bash
pip install git+https://github.com/nupurkmr9/vision-aided-gan.git
```

or

```bash
git clone https://github.com/nupurkmr9/vision-aided-gan.git
cd vision-aided-gan
pip install .
```
For details on the off-the-shelf models, please see MODELS.md.
```python
import vision_aided_loss

device = 'cuda'
discr = vision_aided_loss.Discriminator(cv_type='clip', loss_type='multilevel_sigmoid_s', device=device).to(device)
discr.cv_ensemble.requires_grad_(False)  # freeze the pretrained feature extractors

# Sample images
real = sample_real_image()
fake = G.forward(z)

# Update the discriminator discr
lossD = discr(real, for_real=True) + discr(fake, for_real=False)
lossD.backward()

# Update the generator G
lossG = discr(fake, for_G=True)
lossG.backward()
```
We recommend adding the vision-aided adversarial loss only after the GAN has been trained with the standard adversarial loss for a few warmup iterations (warmup_iter).
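A minimal sketch of that warmup schedule (nothing here beyond `discr` is part of the library; `D`, `G`, `d_loss`, `g_loss`, `sample_real_image`, `sample_z`, the optimizers, `total_iters`, and `warmup_iter` are placeholders for your own training code):

```python
# Hypothetical training-loop fragment: enable the vision-aided loss only after
# `warmup_iter` steps of standard adversarial training.
for step in range(total_iters):
    # --- discriminator update ---
    real = sample_real_image()
    fake = G(sample_z()).detach()
    lossD = d_loss(D, real, fake)                      # standard adversarial loss (placeholder)
    if step >= warmup_iter:
        lossD = lossD + discr(real, for_real=True) + discr(fake, for_real=False)
    opt_D.zero_grad(); lossD.backward(); opt_D.step()

    # --- generator update ---
    fake = G(sample_z())
    lossG = g_loss(D, fake)                            # standard adversarial loss (placeholder)
    if step >= warmup_iter:
        lossG = lossG + discr(fake, for_G=True)
    opt_G.zero_grad(); lossG.backward(); opt_G.step()
```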
Arg details:
- `cv_type`: name of the off-the-shelf model, one of [clip, dino, swin, vgg, det_coco, seg_ade, face_seg, face_normals]. Multiple models can be used with '+'-separated model names (see the sketch after this list).
- `output_type`: output feature type from the off-the-shelf models; should be one of [conv, conv_multi_level]. conv_multi_level is supported only for clip and dino. For multiple models, output_type should be a '+'-separated output_type for each model.
- `diffaug`: if True, performs DiffAugment on the vision-aided discriminator with policy color,translation,cutout. Recommended to keep this True.
- `num_classes`: for conditional training, use num_classes > 0. A projection discriminator is used, similar to BigGAN.
- `loss_type`: should be one of [sigmoid, multilevel_sigmoid, sigmoid_s, multilevel_sigmoid_s, hinge, multilevel_hinge]. Appending _s enables label smoothing. If loss_type is None, the output is a list of logits, one per vision-aided discriminator.
- `device`: device for the off-the-shelf model weights.
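Assuming all of the options above are keyword arguments of vision_aided_loss.Discriminator (as the first example suggests), a multi-model ensemble might be set up roughly like this; the specific model and loss choices here are illustrative, not a recommendation:

```python
import vision_aided_loss

# Sketch: two '+'-separated off-the-shelf models, with one output_type per model.
device = 'cuda'
discr = vision_aided_loss.Discriminator(
    cv_type='clip+dino',
    output_type='conv_multi_level+conv_multi_level',
    loss_type='multilevel_sigmoid_s',
    diffaug=True,           # DiffAugment on the vision-aided discriminator
    num_classes=0,          # set >0 for conditional (projection-discriminator) training
    device=device,
).to(device)
discr.cv_ensemble.requires_grad_(False)  # keep the pretrained backbones frozen
```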
Vision-aided StyleGAN3 training
Please see the stylegan3 README for training StyleGAN3 models with our method.
Vision-aided BigGAN training
Please see the biggan README for training BigGAN models with our method.
To add your own pretrained model
1. Create a class file that extracts pretrained features as vision_module/<custom_model>.py (a rough skeleton is sketched below).
2. Add the class path to class_name_dict in the vision_module.cvmodel.CVBackbone class.
3. Update the architecture of the trainable classifier head over the pretrained features in vision_module.cv_discriminator.
4. Reinstall the library via pip install .
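As a rough illustration only (not the library's actual interface; mirror one of the existing files in vision_module for the exact signatures CVBackbone expects), a custom feature-extractor module might look like:

```python
# vision_module/my_backbone.py -- hypothetical skeleton of a custom feature extractor.
import torch.nn as nn
import torchvision

class MyBackbone(nn.Module):
    def __init__(self, device='cuda'):
        super().__init__()
        resnet = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2)
        self.features = nn.Sequential(*list(resnet.children())[:-2])  # drop avgpool + fc
        self.features.eval().requires_grad_(False)  # pretrained backbone stays frozen

    def forward(self, x):
        # Return spatial ('conv'-style) features for the trainable discriminator head.
        return self.features(x)
```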
References
```bibtex
@InProceedings{kumari2021ensembling,
  title     = {Ensembling Off-the-shelf Models for GAN Training},
  author    = {Kumari, Nupur and Zhang, Richard and Shechtman, Eli and Zhu, Jun-Yan},
  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month     = {June},
  year      = {2022}
}
```
Acknowledgments
We thank Muyang Li, Sheng-Yu Wang, Chonghyuk (Andrew) Song for proofreading the draft. We are also grateful to Alexei A. Efros, Sheng-Yu Wang, Taesung Park, and William Peebles for helpful comments and discussion. Our codebase is built on stylegan2-ada-pytorch and DiffAugment.