
Vision-aided GAN training


Vision-aided GAN


video | website | paper

Can the collective knowledge from a large bank of pretrained vision models be leveraged to improve GAN training? If so, with so many models to choose from, which one(s) should be selected, and in what manner are they most effective?

We find that pretrained computer vision models can significantly improve performance when used in an ensemble of discriminators. We propose an effective selection mechanism, by probing the linear separability between real and fake samples in pretrained model embeddings, choosing the most accurate model, and progressively adding it to the discriminator ensemble. Our method can improve GAN training in both limited data and large-scale settings.
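The selection step described above can be sketched roughly as follows. This is a minimal illustration, not the authors' implementation: it assumes the real and fake embeddings of each candidate model are already available as NumPy arrays, uses a plain logistic-regression probe, and the names linear_probe_accuracy, select_model, model_a, and model_b are hypothetical. The synthetic features stand in for actual pretrained-model embeddings.

```python
import numpy as np


def linear_probe_accuracy(real_feats, fake_feats, steps=200, lr=0.5, seed=0):
    """Fit a logistic-regression probe on half the data; return held-out accuracy."""
    rng = np.random.default_rng(seed)
    x = np.concatenate([real_feats, fake_feats]).astype(np.float64)
    y = np.concatenate([np.ones(len(real_feats)), np.zeros(len(fake_feats))])

    # Shuffle, then split into a training half and a validation half.
    perm = rng.permutation(len(x))
    x, y = x[perm], y[perm]
    n_train = len(x) // 2
    x_tr, y_tr, x_va, y_va = x[:n_train], y[:n_train], x[n_train:], y[n_train:]

    # Standardize with training statistics only.
    mu, sd = x_tr.mean(axis=0), x_tr.std(axis=0) + 1e-8
    x_tr, x_va = (x_tr - mu) / sd, (x_va - mu) / sd

    # Plain gradient descent on the logistic loss.
    w, b = np.zeros(x.shape[1]), 0.0
    for _ in range(steps):
        logits = np.clip(x_tr @ w + b, -30.0, 30.0)
        grad = 1.0 / (1.0 + np.exp(-logits)) - y_tr
        w -= lr * (x_tr.T @ grad) / n_train
        b -= lr * grad.mean()

    pred_real = (x_va @ w + b) > 0
    return float((pred_real == (y_va > 0.5)).mean())


def select_model(candidates):
    """Pick the candidate whose embeddings separate real from fake most accurately."""
    scores = {name: linear_probe_accuracy(r, f) for name, (r, f) in candidates.items()}
    return max(scores, key=scores.get), scores


# Synthetic stand-ins: model_a separates real and fake, model_b does not.
rng = np.random.default_rng(0)
candidates = {
    "model_a": (rng.normal(0, 1, (200, 8)), rng.normal(2, 1, (200, 8))),
    "model_b": (rng.normal(0, 1, (200, 8)), rng.normal(0, 1, (200, 8))),
}
best, scores = select_model(candidates)
```

In a progressive setup, the selected model would be added to the discriminator ensemble and the probe repeated later in training on the remaining candidates.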

Ensembling Off-the-shelf Models for GAN Training
Nupur Kumari, Richard Zhang, Eli Shechtman, Jun-Yan Zhu
In CVPR 2022

Pretrained Models

Our final trained models can be downloaded at this link. For usage details, please see the README files in the stylegan2 and biggan folders.

Vision-aided StyleGAN2 training

Please see the stylegan2 README for training StyleGAN2 models with our method. This code reproduces all StyleGAN2-based results from our paper.

Vision-aided Discriminator in a custom GAN model

pip install vision-aided-loss

To install from GitHub instead, run pip install git+https://github.com/nupurkmr9/vision-aided-gan.git or clone the repository and install it locally:

git clone https://github.com/nupurkmr9/vision-aided-gan.git
cd vision-aided-gan
pip install .

For details on off-the-shelf models please see MODELS.md

import vision_aided_loss

device='cuda'
discr = vision_aided_loss.Discriminator(cv_type='clip', loss_type='multilevel_sigmoid_s', device=device).to(device)
discr.cv_ensemble.requires_grad_(False) # Freeze feature extractor

# Sample images (sample_real_image, G, and z come from your own training code)
real = sample_real_image()
fake = G(z)

# Update discriminator discr
lossD = discr(real, for_real=True) + discr(fake, for_real=False)
lossD.backward()

# Update generator G
lossG = discr(fake, for_G=True)
lossG.backward()

# We recommend adding the vision-aided adversarial loss only after first training the GAN with the standard loss for warmup_iter iterations.
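The warmup recommendation can be sketched as a small gate around the loss. This is only an illustration under assumptions: the helper names vision_aided_active and combine_generator_loss are hypothetical and not part of the library, and the losses are plain floats here where a real training loop would pass tensors.

```python
def vision_aided_active(step: int, warmup_iter: int) -> bool:
    """Return True once the warmup phase with the standard loss is over."""
    return step >= warmup_iter


def combine_generator_loss(standard_loss, vision_aided_term, step, warmup_iter):
    """Add the vision-aided term to the standard loss only after warmup."""
    if vision_aided_active(step, warmup_iter):
        return standard_loss + vision_aided_term
    return standard_loss
```

The same gate would apply to the discriminator update, so that discr is neither trained nor used before warmup_iter is reached.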

Arg details:

  • cv_type: name of the off-the-shelf model from [clip, dino, swin, vgg, det_coco, seg_ade, face_seg, face_normals]. Multiple models can be used with '+' separated model names.
  • output_type: output feature type from the off-the-shelf models. Should be one of [conv, conv_multi_level]. conv_multi_level is supported only for clip and dino. For multiple models, output_type should be a '+' separated list with one output_type per model.
  • diffaug: if True, performs DiffAugment on the vision-aided discriminator with policy color,translation,cutout. Recommended to keep this as True.
  • num_classes: for conditional training use num_classes>0. A projection discriminator is used, similar to BigGAN.
  • loss_type: should be one of [sigmoid, multilevel_sigmoid, sigmoid_s, multilevel_sigmoid_s, hinge, multilevel_hinge]. Appending _s enables label smoothing. If loss_type is None, the output is a list of logits, one for each vision-aided discriminator.
  • device: device for off-the-shelf model weights.
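As an illustration of the '+' separated convention for cv_type and output_type, the following sketch builds matching argument strings for an ensemble. The helper ensemble_args is hypothetical and not part of the library; the model names and the clip/dino restriction are taken from the list above.

```python
VALID_MODELS = {"clip", "dino", "swin", "vgg", "det_coco", "seg_ade",
                "face_seg", "face_normals"}
MULTI_LEVEL_MODELS = {"clip", "dino"}  # only these support conv_multi_level


def ensemble_args(models, multi_level=True):
    """Build '+' separated cv_type and output_type strings for a list of models."""
    unknown = [m for m in models if m not in VALID_MODELS]
    if unknown:
        raise ValueError(f"unknown off-the-shelf model(s): {unknown}")
    outputs = [
        "conv_multi_level" if multi_level and m in MULTI_LEVEL_MODELS else "conv"
        for m in models
    ]
    return {"cv_type": "+".join(models), "output_type": "+".join(outputs)}


args = ensemble_args(["clip", "swin"])
# args can then be passed on, for example:
# discr = vision_aided_loss.Discriminator(loss_type='multilevel_sigmoid_s',
#                                         device=device, **args)
```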

Vision-aided StyleGAN3 training

Please see stylegan3 README for training StyleGAN3 models with our method.

Vision-aided BigGAN training

Please see biggan README for training BigGAN models with our method.

To add your own pretrained model

Create the class file that extracts pretrained features as vision_module/<custom_model>.py. Add the class path to class_name_dict in the vision_module.cvmodel.CVBackbone class. Update the architecture of the trainable classifier head over the pretrained features in vision_module.cv_discriminator. Then reinstall the library via pip install .
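A registry that maps a model name to a class path is commonly resolved with importlib. The sketch below shows that general pattern only; how CVBackbone actually resolves class_name_dict is an assumption here, load_class is a hypothetical helper, and a standard-library class stands in for a custom feature extractor.

```python
import importlib


def load_class(dotted_path: str):
    """Import 'package.module.ClassName' and return the class object."""
    module_name, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Stand-in registry. A custom entry might look like
# "my_model": "vision_module.my_model.MyModel" (hypothetical names).
class_name_dict = {
    "ordered": "collections.OrderedDict",
}

resolved = load_class(class_name_dict["ordered"])
```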

References

@InProceedings{kumari2021ensembling,
  title={Ensembling Off-the-shelf Models for GAN Training},
  author={Kumari, Nupur and Zhang, Richard and Shechtman, Eli and Zhu, Jun-Yan},
  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month     = {June},
  year      = {2022}
}

Acknowledgments

We thank Muyang Li, Sheng-Yu Wang, Chonghyuk (Andrew) Song for proofreading the draft. We are also grateful to Alexei A. Efros, Sheng-Yu Wang, Taesung Park, and William Peebles for helpful comments and discussion. Our codebase is built on stylegan2-ada-pytorch and DiffAugment.
