In many domains of computer vision, generative adversarial networks (GANs) have achieved great success, among which the family of Wasserstein GANs (WGANs) is considered to be state-of-the-art due to the theoretical contributions and competitive qualitative performance.

WassersteinGAN_DIV-PyTorch

Update (Feb 22, 2020)

The mnist and fmnist models are now available. Their usage is identical to the other models:

from wgandiv_pytorch import Generator
model = Generator.from_pretrained('g-mnist') 

Overview

This repository contains an op-for-op PyTorch reimplementation of Wasserstein Divergence for GANs.

The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects. This implementation is a work in progress -- new features are currently being implemented.

At the moment, you can easily:

  • Load pretrained Generator models
  • Use Generator models to extend a dataset with generated samples

Upcoming features: In the next few days, you will be able to:

  • Quickly finetune a Generator on your own dataset
  • Export Generator models for production

Table of contents

  1. About Wasserstein GAN DIV
  2. Model Description
  3. Installation
  4. Usage
  5. Contributing

About Wasserstein GAN DIV

If you're new to Wasserstein GAN DIV, here's the abstract straight from the paper:

In many domains of computer vision, generative adversarial networks (GANs) have achieved great success, among which the family of Wasserstein GANs (WGANs) is considered to be state-of-the-art due to the theoretical contributions and competitive qualitative performance. However, it is very challenging to approximate the k-Lipschitz constraint required by the Wasserstein-1 metric (W-met). In this paper, we propose a novel Wasserstein divergence (W-div), which is a relaxed version of W-met and does not require the k-Lipschitz constraint. As a concrete application, we introduce a Wasserstein divergence objective for GANs (WGAN-div), which can faithfully approximate W-div through optimization. Under various settings, including progressive growing training, we demonstrate the stability of the proposed WGAN-div owing to its theoretical and practical advantages over WGANs. Also, we study the quantitative and visual performance of WGAN-div on standard image synthesis benchmarks, showing the superior performance of WGAN-div compared to the state-of-the-art methods.

Model Description

We have two networks, G (Generator) and D (Discriminator). The Generator produces images: it receives a random noise vector z and maps it to an image, written G(z). The Discriminator judges whether an image is real: its input x is an image, and its output D(x) scores how real x appears. In the classic GAN formulation, D(x) is a probability, where 1 means x is certainly real and 0 means it is certainly generated. In Wasserstein-style GANs such as WGAN-div, D is usually treated as a critic whose output is an unbounded score rather than a probability; higher scores indicate more realistic images.
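To make the roles of G and D concrete, here is a minimal sketch of one critic-loss computation in the WGAN-div style. It is not taken from this package: the tiny fully connected networks, the helper name `wgan_div_critic_loss`, the data dimension of 784, and the choice to evaluate the gradient term on points interpolated between real and generated samples are all assumptions made for illustration. The values k = 2 and p = 6 follow the defaults commonly quoted for WGAN-div, and sign conventions vary between implementations.

```python
import torch
from torch import nn


def wgan_div_critic_loss(critic, real, fake, k=2.0, p=6.0):
    """Hypothetical helper: E[D(fake)] - E[D(real)] + k * E[||grad D(x_hat)||^p]."""
    fake = fake.detach()  # the critic step must not update the generator
    batch = real.size(0)
    # Random points on straight lines between real and generated samples.
    eps = torch.rand(batch, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1.0 - eps) * fake).detach().requires_grad_(True)

    scores = critic(x_hat)
    (grad,) = torch.autograd.grad(scores.sum(), x_hat, create_graph=True)
    # Per-sample ||grad||^p. Unlike WGAN-GP there is no "(norm - 1)" target.
    grad_norm_p = grad.flatten(1).pow(2).sum(dim=1).pow(p / 2.0)

    return critic(fake).mean() - critic(real).mean() + k * grad_norm_p.mean()


torch.manual_seed(0)
latent_dim, data_dim = 100, 784  # 784 = 28 * 28 is an assumption for MNIST-sized data

# Toy stand-ins for G and D, not the architectures shipped with this package.
G = nn.Sequential(nn.Linear(latent_dim, 128), nn.ReLU(),
                  nn.Linear(128, data_dim), nn.Tanh())
D = nn.Sequential(nn.Linear(data_dim, 128), nn.LeakyReLU(0.2),
                  nn.Linear(128, 1))  # no sigmoid: the critic outputs a raw score

z = torch.randn(8, latent_dim)
real = torch.rand(8, data_dim) * 2 - 1  # random stand-in for real images in [-1, 1]
fake = G(z)

loss = wgan_div_critic_loss(D, real, fake)
loss.backward()  # fills gradients for D only
print(f"critic loss: {loss.item():.4f}")
```

In a full training loop this step would alternate with a generator step that minimizes `-D(G(z)).mean()` under the same sign convention.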

Installation

Install from pypi:

pip install wgandiv_pytorch

Install from source:

git clone https://github.com/Lornatang/WassersteinGAN_DIV-PyTorch.git
cd WassersteinGAN_DIV-PyTorch
pip install -e .
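After either install route, a quick check that the package is importable can save debugging time later. The sketch below uses only the Python standard library; the helper name `is_installed` is made up for this example and is not part of wgandiv_pytorch.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module with this name is on the import path."""
    return importlib.util.find_spec(module_name) is not None


# Reports the status without raising if the package is missing.
status = "found" if is_installed("wgandiv_pytorch") else "not found"
print(f"wgandiv_pytorch: {status}")
```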

Usage

Loading pretrained models

Load a Wasserstein GAN DIV generator:

from wgandiv_pytorch import Generator
model = Generator.from_name("g-mnist")

Load a pretrained Wasserstein GAN DIV generator:

from wgandiv_pytorch import Generator
model = Generator.from_pretrained("g-mnist")

Example: Extended dataset

This example loads the generator pretrained on MNIST, creates an imgs directory if one does not already exist, and writes 64 image files to it. Each file is produced from a fresh batch of 64 random noise vectors.

import os
import torch
import torchvision.utils as vutils
from wgandiv_pytorch import Generator

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Generator.from_pretrained("g-mnist")
model.to(device)
# switch to evaluation mode
model.eval()

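# create the output directory if it does not already exist
os.makedirs("./imgs", exist_ok=True)

with torch.no_grad():
    for i in range(64):
        # one batch of 64 latent vectors, each of dimension 100
        noise = torch.randn(64, 100, device=device)
        fake = model(noise)
        # autograd is disabled here, so no detach() is needed
        vutils.save_image(fake, f"./imgs/fake_{i:04d}.png", normalize=True)

print("The fake images have been generated!")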

Example: Visual

cd $REPO$/framework
sh start.sh

Then open http://127.0.0.1:10004/ in your browser.

Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.

I look forward to seeing what the community does with these models!
