The recently proposed Wasserstein GAN (WGAN) makes progress toward stable training of GANs.

Project description

WassersteinGAN_GP-PyTorch

Update (Feb 21, 2020)

The mnist and fmnist models are now available. Their usage is identical to the other models:

from wgangp_pytorch import Generator
model = Generator.from_pretrained('g-mnist') 

Overview

This repository contains an op-for-op PyTorch reimplementation of Improved Training of Wasserstein GANs.

The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects. This implementation is a work in progress -- new features are currently being implemented.

At the moment, you can easily:

  • Load pretrained Generator models
  • Use Generator models to extend a dataset with generated samples

Upcoming features: In the next few days, you will be able to:

  • Quickly finetune a Generator on your own dataset
  • Export Generator models for production

Table of contents

  1. About Wasserstein GAN GP
  2. Model Description
  3. Installation
  4. Usage
  5. Contributing

About Wasserstein GAN GP

If you're new to Wasserstein GAN GP, here's an abstract straight from the paper:

Generative Adversarial Networks (GANs) are powerful generative models, but suffer from training instability. The recently proposed Wasserstein GAN (WGAN) makes progress toward stable training of GANs, but sometimes can still generate only low-quality samples or fail to converge. We find that these problems are often due to the use of weight clipping in WGAN to enforce a Lipschitz constraint on the critic, which can lead to undesired behavior. We propose an alternative to clipping weights: penalize the norm of gradient of the critic with respect to its input. Our proposed method performs better than standard WGAN and enables stable training of a wide variety of GAN architectures with almost no hyperparameter tuning, including 101-layer ResNets and language models over discrete data. We also achieve high quality generations on CIFAR-10 and LSUN bedrooms.
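The gradient penalty described in the abstract can be sketched in a few lines of PyTorch. This is a minimal illustration, not this package's training code: `gradient_penalty`, `toy_critic`, and the random tensors standing in for real and generated batches are hypothetical names introduced here. The penalty weight of 10 is the value used in the paper.

```python
import torch
from torch import nn


def gradient_penalty(critic, real, fake):
    """Mean of (||grad of critic w.r.t. x_hat||_2 - 1)^2 over random interpolates."""
    batch_size = real.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dimensions.
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    # Points on straight lines between real and generated samples.
    x_hat = (eps * real + (1.0 - eps) * fake).detach().requires_grad_(True)
    scores = critic(x_hat)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()


# Hypothetical toy critic and random batches, used only to exercise the function.
toy_critic = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))
real = torch.randn(8, 1, 28, 28)
fake = torch.randn(8, 1, 28, 28)

penalty = gradient_penalty(toy_critic, real, fake)
# Critic loss: score gap between generated and real samples plus the weighted penalty.
critic_loss = toy_critic(fake).mean() - toy_critic(real).mean() + 10.0 * penalty
```

Calling `critic_loss.backward()` then produces gradients for the critic's parameters, with the penalty term replacing the weight clipping used in the original WGAN.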

Model Description

We have two networks, G (Generator) and D (Discriminator). The Generator produces images: it receives random noise z and maps it to an image, written G(z). The Discriminator judges whether an image is real: its input x is an image and its output is D(x). In a standard GAN, D(x) is the probability that x is a real image, where 1 means certainly real and 0 means certainly fake. In Wasserstein GAN GP, D acts as a critic: its output is an unbounded score rather than a probability, with higher scores for images it judges more real.
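The data flow between the two networks can be sketched as follows. `ToyGenerator` and `ToyDiscriminator` are hypothetical small MLPs written for illustration; they are not the architectures shipped in this package. The noise size of 100 matches the usage example in this README, and the 1x28x28 image shape is an assumption based on MNIST.

```python
import torch
from torch import nn

NOISE_DIM = 100             # noise size, as in the usage example
IMAGE_SHAPE = (1, 28, 28)   # assumed MNIST-sized grayscale images


class ToyGenerator(nn.Module):
    """Hypothetical MLP generator: noise z -> image G(z) with values in [-1, 1]."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(NOISE_DIM, 256),
            nn.ReLU(),
            nn.Linear(256, 28 * 28),
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z).view(-1, *IMAGE_SHAPE)


class ToyDiscriminator(nn.Module):
    """Hypothetical MLP discriminator: image x -> one score D(x) per image."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        return self.net(x)


z = torch.randn(16, NOISE_DIM)            # a batch of 16 noise vectors
fake_images = ToyGenerator()(z)           # G(z): shape (16, 1, 28, 28)
scores = ToyDiscriminator()(fake_images)  # D(G(z)): shape (16, 1)
```

The discriminator here ends in a plain linear layer, so it returns an unbounded score; appending `nn.Sigmoid()` would turn that score into a probability.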

Installation

Install from pypi:

pip install wgangp_pytorch

Install from source:

git clone https://github.com/Lornatang/WassersteinGAN_GP-PyTorch.git
cd WassersteinGAN_GP-PyTorch
pip install -e .

Usage

Loading pretrained models

Load a Wasserstein GAN GP generator:

from wgangp_pytorch import Generator
model = Generator.from_name("g-mnist")

Load a pretrained Wasserstein GAN GP:

from wgangp_pytorch import Generator
model = Generator.from_pretrained("g-mnist")

Example: Extended dataset

The example below loads the generator pretrained on MNIST, creates an imgs directory if it does not exist, and writes 64 image files to it. Each file is a grid of 64 generated samples.

import os
import torch
import torchvision.utils as vutils
from wgangp_pytorch import Generator

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Generator.from_pretrained("g-mnist")
model.to(device)
# switch to evaluate mode
model.eval()

# create the output directory if it does not already exist
os.makedirs("./imgs", exist_ok=True)

with torch.no_grad():
    for i in range(64):
        noise = torch.randn(64, 100, device=device)
        fake = model(noise)
        vutils.save_image(fake.detach(), f"./imgs/fake_{i:04d}.png", normalize=True)
    print("The fake image has been generated!")

Example: Visual

cd $REPO$/framework
sh start.sh

Then open http://127.0.0.1:10003/ in your browser.

Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.

I look forward to seeing what the community does with these models!
