
This is a new algorithm named WGAN, an alternative to traditional GAN training!

Project description

WassersteinGAN-PyTorch

Update (Feb 21, 2020)

The MNIST and Fashion-MNIST (fmnist) models are now available. They are used exactly like the other models:

from wgan_pytorch import Generator
model = Generator.from_pretrained('g-mnist') 

Overview

This repository contains an op-for-op PyTorch reimplementation of Wasserstein GAN.

The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects. This implementation is a work in progress, and new features are currently being implemented.

At the moment, you can easily:

  • Load pretrained Generator models
  • Use Generator models to extend a dataset with generated images

Upcoming features: In the next few days, you will be able to:

  • Quickly fine-tune a Generator on your own dataset
  • Export Generator models for production

Table of contents

  1. About Wasserstein GAN
  2. Model Description
  3. Installation
  4. Usage
  5. Contributing

About Wasserstein GAN

If you're new to Wasserstein GAN, here's an abstract straight from the paper:

We introduce a new algorithm named WGAN, an alternative to traditional GAN training. In this new model, we show that we can improve the stability of learning, get rid of problems like mode collapse, and provide meaningful learning curves useful for debugging and hyperparameter searches. Furthermore, we show that the corresponding optimization problem is sound, and provide extensive theoretical work highlighting the deep connections to other distances between distributions.
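To make the "distances between distributions" in the abstract a little more concrete, here is a minimal pure-Python sketch of the Wasserstein-1 (earth mover's) distance in the one case where it is trivial to compute exactly: two one-dimensional samples of equal size, where matching the sorted values gives the optimal transport plan. The function name `wasserstein_1d` is hypothetical and is not part of wgan_pytorch; the sketch only illustrates the quantity that WGAN approximates.

```python
def wasserstein_1d(xs, ys):
    """Exact W1 distance between two equal-size 1D empirical samples.

    Hypothetical helper. In one dimension the optimal transport plan pairs
    the i-th smallest value of one sample with the i-th smallest of the other.
    """
    if len(xs) != len(ys) or not xs:
        raise ValueError("samples must be non-empty and of equal size")
    xs_sorted, ys_sorted = sorted(xs), sorted(ys)
    return sum(abs(x - y) for x, y in zip(xs_sorted, ys_sorted)) / len(xs)


# Point masses at 0 and theta: W1 shrinks smoothly as theta approaches 0,
# whereas the Jensen-Shannon divergence stays at log 2 for every theta != 0.
for theta in (2.0, 1.0, 0.5, 0.0):
    print(theta, wasserstein_1d([0.0], [theta]))
```

This smooth dependence on the distance between the two distributions is what lets the WGAN loss act as a meaningful learning curve.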

Model Description

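We have two networks, G (Generator) and D (Discriminator). The Generator produces images: it takes a random noise vector z and maps it to an image, written G(z). The Discriminator receives an image x and judges whether it is real. In the original GAN, D(x) is the probability that x is a real image (1 means certainly real, 0 means certainly fake). In WGAN, D is instead called the critic: it has no final sigmoid and outputs an unbounded real-valued score, and the gap between its average score on real images and its average score on generated images serves as an estimate (up to a constant factor) of the Wasserstein distance between the two distributions. To keep this estimate valid, the critic's weights are clipped to a small range [-c, c] after every update.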
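The sketch below writes out the WGAN objectives in plain Python, assuming the critic's scores for a batch of real and generated images have already been computed. The helper names `critic_loss`, `generator_loss` and `clip_weights` are hypothetical and do not correspond to functions in wgan_pytorch; in an actual PyTorch training loop these would be tensor operations, and clipping would be done by clamping each critic parameter in place (for example with `Tensor.clamp_`).

```python
from statistics import mean

# Hypothetical helpers illustrating the WGAN objectives; not the wgan_pytorch API.


def critic_loss(real_scores, fake_scores):
    """Loss minimized by the critic D.

    Minimizing mean(D(G(z))) - mean(D(x)) widens the gap between real and
    fake scores; the negated loss is the critic's Wasserstein estimate.
    """
    return mean(fake_scores) - mean(real_scores)


def generator_loss(fake_scores):
    """Loss minimized by the generator G: push the critic's scores on fakes up."""
    return -mean(fake_scores)


def clip_weights(weights, c=0.01):
    """Clamp every critic weight to [-c, c]; c = 0.01 is the paper's default."""
    return [max(-c, min(c, w)) for w in weights]


# Critic scores are unbounded real numbers, not probabilities.
real_scores = [1.0, 2.0, 3.0]
fake_scores = [-2.0, 0.0, -1.0]
print(critic_loss(real_scores, fake_scores))  # -3.0, i.e. a W estimate of 3.0
print(generator_loss(fake_scores))            # 1.0
print(clip_weights([0.5, -0.003, -2.0]))      # [0.01, -0.003, -0.01]
```

In the paper's Algorithm 1 the critic is updated n_critic = 5 times (clipping after each update) for every generator update, using RMSProp with a learning rate of 0.00005.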

Installation

Install from pypi:

pip install wgan_pytorch

Install from source:

git clone https://github.com/Lornatang/WassersteinGAN-PyTorch.git
cd WassersteinGAN-PyTorch
pip install -e .

Usage

Loading pretrained models

Load a Wasserstein GAN generator without pretrained weights:

from wgan_pytorch import Generator
model = Generator.from_name("g-mnist")

Load a pretrained Wasserstein GAN:

from wgan_pytorch import Generator
model = Generator.from_pretrained("g-mnist")

Example: Extended dataset

This example loads the generator pretrained on MNIST, creates an imgs directory if it does not already exist, and saves 64 randomly generated images into it.

import os

import torch
import torchvision.utils as vutils
from wgan_pytorch import Generator

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Generator.from_pretrained("g-mnist")
model.to(device)
# switch to evaluation mode
model.eval()

# create the output directory if it does not already exist
os.makedirs("./imgs", exist_ok=True)

with torch.no_grad():
    # sample one batch of 64 latent vectors (latent size 100, as in the original example)
    noise = torch.randn(64, 100, device=device)
    fake = model(noise)
    # save each generated image to its own file
    for i in range(fake.size(0)):
        vutils.save_image(fake[i], f"./imgs/fake_{i:04d}.png", normalize=True)
print("The fake images have been generated!")
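The example passes `normalize=True` to `vutils.save_image`. Per the torchvision documentation, this shifts the image into the range (0, 1) using the tensor's own minimum and maximum (unless `value_range` is given) before it is written as 8-bit pixels. Below is a rough pure-Python approximation of that conversion for a flat list of values; the function name `to_uint8_pixels` is hypothetical, and the small epsilon guarding against division by zero for constant images is an assumption about torchvision's internals.

```python
def to_uint8_pixels(values, eps=1e-5):
    """Approximate save_image(..., normalize=True) for a flat list of values.

    Hypothetical helper: min-max scale to [0, 1], then map to 0..255.
    """
    low, high = min(values), max(values)
    scale = max(high - low, eps)  # assumed guard for constant inputs
    pixels = []
    for v in values:
        unit = (v - low) / scale               # now in [0, 1]
        byte = int(unit * 255 + 0.5)           # round half up to an integer
        pixels.append(max(0, min(255, byte)))  # clamp to the valid byte range
    return pixels


# Values spanning [-1, 1] map onto the full byte range.
print(to_uint8_pixels([-1.0, 0.0, 1.0]))  # [0, 128, 255]
```

Because the minimum and maximum come from the tensor being saved, saving images one at a time rescales each image independently, while saving a whole batch as one grid uses a shared range unless `scale_each=True` is passed.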

Example: Visual

cd $REPO$/framework
sh start.sh

Then open http://127.0.0.1:10002/ in your browser. Enjoy!

Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.

I look forward to seeing what the community does with these models!



Download files


Source Distribution

wgan_pytorch-0.1.2.tar.gz (6.8 kB)

Uploaded Source

Built Distribution

wgan_pytorch-0.1.2-py2.py3-none-any.whl (11.0 kB)

Uploaded Python 2 Python 3

File details

Details for the file wgan_pytorch-0.1.2.tar.gz.

File metadata

  • Download URL: wgan_pytorch-0.1.2.tar.gz
  • Upload date:
  • Size: 6.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6

File hashes

Hashes for wgan_pytorch-0.1.2.tar.gz
Algorithm Hash digest
SHA256 b58528e086d7a9b66a7ca9e0ad1e61779621c71dff044d188fd2e73973e56a9d
MD5 23d0092700d703b0db9190e8cd15adc1
BLAKE2b-256 eebd760187a6b7002d9372960f0739ab71506825d779afc2a750428132dbb6c6


File details

Details for the file wgan_pytorch-0.1.2-py2.py3-none-any.whl.

File metadata

  • Download URL: wgan_pytorch-0.1.2-py2.py3-none-any.whl
  • Upload date:
  • Size: 11.0 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6

File hashes

Hashes for wgan_pytorch-0.1.2-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 2eff831bb48a9f47319fa705bc70b9fe732bac14f393cbfd80fca1a3ac2cdac8
MD5 03a2441212d964b556c4c22322740fbd
BLAKE2b-256 90b19785cab1320444e0ba0403a37b8dbd9c68ea27f2bdaca054892c94661d41

