A simple GAN neural network implemented in PyTorch.

GAN-PyTorch

Update (February 16, 2020)

Now you can install this library directly using pip!

$ pip3 install --upgrade gan_pytorch

Update (January 29, 2020)

The mnist and fmnist models are now available. Their usage is identical to the other models:

from gan_pytorch import Generator
model = Generator.from_pretrained('g-mnist') 

Overview

This repository contains an op-for-op PyTorch reimplementation of Generative Adversarial Networks.

The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects. This implementation is a work in progress -- new features are currently being implemented.

At the moment, you can easily:

  • Load pretrained Generator models
  • Use Generator models to extend a dataset

Upcoming features: In the next few days, you will be able to:

  • Quickly finetune a Generator on your own dataset
  • Export Generator models for production

Table of contents

  1. About Generative Adversarial Networks
  2. Model Description
  3. Installation
  4. Usage
  5. Contributing

About Generative Adversarial Networks

If you're new to GANs, here's the abstract, straight from the paper:

We propose a new framework for estimating generative models via an adversarial process, in which we simultaneously train two models: a generative model G that captures the data distribution, and a discriminative model D that estimates the probability that a sample came from the training data rather than G. The training procedure for G is to maximize the probability of D making a mistake. This framework corresponds to a minimax two-player game. In the space of arbitrary functions G and D, a unique solution exists, with G recovering the training data distribution and D equal to 1/2 everywhere. In the case where G and D are defined by multilayer perceptrons, the entire system can be trained with backpropagation. There is no need for any Markov chains or unrolled approximate inference networks during either training or generation of samples. Experiments demonstrate the potential of the framework through qualitative and quantitative evaluation of the generated samples.
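
For reference, the minimax two-player game mentioned in the abstract is usually written as the following value function, where p_data is the data distribution and p_z is the noise prior (notation as in the GAN paper):

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

D tries to make both terms large by scoring real samples near 1 and generated samples near 0, while G tries to make the second term small by pushing D(G(z)) toward 1.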

Model Description

We have two networks: G (the Generator) and D (the Discriminator). The Generator produces images: it receives random noise z and generates an image from it, denoted G(z). The Discriminator judges whether an image is real: its input x is an image, and its output D(x) is the probability that x is real. An output of 1 means the image is certainly real, and an output of 0 means it is certainly not.
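
To make the roles of D(x) and D(G(z)) concrete, here is a minimal sketch in plain Python of the two losses implied by the minimax game in the abstract. It is an illustration only: the function names `discriminator_loss` and `generator_loss` and the sample probabilities are assumptions, not part of the gan_pytorch API, and real training code would compute these on tensors.

```python
import math


def discriminator_loss(d_real, d_fake, eps=1e-12):
    """Loss D minimizes: -(mean log D(x) + mean log(1 - D(G(z)))).

    d_real: D's outputs on real images, each a probability in [0, 1].
    d_fake: D's outputs on generated images G(z).
    eps guards against log(0); it is an illustrative choice.
    """
    real_term = sum(math.log(p + eps) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p + eps) for p in d_fake) / len(d_fake)
    return -(real_term + fake_term)


def generator_loss(d_fake, eps=1e-12):
    """Original minimax objective for G: minimize mean log(1 - D(G(z)))."""
    return sum(math.log(1.0 - p + eps) for p in d_fake) / len(d_fake)


# A confident, correct discriminator: real images near 1, fakes near 0.
confident = discriminator_loss([0.9, 0.95], [0.05, 0.1])

# A discriminator that outputs 1/2 everywhere, the equilibrium
# described in the abstract; its loss is 2 * log(2).
fooled = discriminator_loss([0.5, 0.5], [0.5, 0.5])

print(f"confident D loss: {confident:.4f}")
print(f"fooled D loss:    {fooled:.4f}")
```

The generator's loss falls as D(G(z)) rises, which is the sense in which G is trained to maximize the probability of D making a mistake.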

Installation

Install from PyPI:

$ pip3 install gan_pytorch

Install from source:

$ git clone https://github.com/lornatang/Generative-Adversarial-Networks
$ cd Generative-Adversarial-Networks
$ pip3 install -r requirements.txt
$ pip3 install -e .
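
After either install, one quick way to confirm that the package is importable is a short standard-library check. This is a sketch: the helper name `is_installed` is hypothetical, and the module name `gan_pytorch` is taken from the import statements in the usage examples.

```python
import importlib.util


def is_installed(module_name):
    """Return True if a top-level module can be found on the import path."""
    return importlib.util.find_spec(module_name) is not None


# Reports whether the gan_pytorch module is visible to this interpreter.
print("gan_pytorch installed:", is_installed("gan_pytorch"))
```

If this prints False, the package was likely installed for a different Python interpreter than the one running the check.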

Usage

Loading pretrained models

Load a Generator by name:

from gan_pytorch import Generator
model = Generator.from_name("g-mnist")

Load a pretrained Generator:

from gan_pytorch import Generator
model = Generator.from_pretrained("g-mnist")

Example: Extended dataset

This example loads the Generator weights pretrained on MNIST, creates an imgs directory if needed, and saves 64 image files there, each generated from a batch of 64 random noise vectors.

import os
import torch
import torchvision.utils as vutils
from gan_pytorch import Generator

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Generator.from_pretrained("g-mnist")
model.to(device)
# switch to evaluate mode
model.eval()

# create the output directory if it does not already exist
os.makedirs("./imgs", exist_ok=True)

with torch.no_grad():
    for i in range(64):
        # each file is generated from a fresh batch of 64 noise vectors of size 100
        noise = torch.randn(64, 100, device=device)
        fake = model(noise)
        vutils.save_image(fake.cpu(), f"./imgs/fake_{i:04d}.png", normalize=True)
print("The fake images have been generated!")

Example: Visual

cd $REPO$/framework
sh start.sh

Then open http://127.0.0.1:10000/ in your browser.

Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.

I look forward to seeing what the community does with these models!
