Skip to main content

A library to easily train various existing GANs in PyTorch.

Project description

vegans

A library to easily train various existing GANs (and other generative models) in PyTorch.

This library targets mainly GAN users, who want to use existing GAN training techniques with their own generators/discriminators. However researchers may also find the GenerativeModel base class useful for quicker implementation of new GAN training techniquess.

The focus is on simplicity and providing reasonable defaults.

How to install

You need python 3.7 or above. Then: pip install vegans

How to use

The basic idea is that the user provides discriminator / critic and generator networks (additionally an encoder if needed), and the library takes care of training them in a selected GAN setting. To get familiar with the library:

vegans implements two types of generative models: Unsupervised and Supervised (examples given below). Unsupervised algorithms are used when no labels exist for the data you want to generate, for example in cases where it is too tedious or infeasible to generate labels for every output. The disadvantage is that after training the generation process will be unsupervised as well, meaning you have (in most cases) little control over which type of output is generated. Supervised algorithms on the other hand require you to specify the input dimension of the label (y_dim) and provide labels during training. All algorithms requiring labels are implemented as "ConditionalGAN" (e.g. VanillaGAN does not take labels, whereas ConditionalVanillaGAN does). These algorithms enable you to generate a specific output conditonal on a certain input.

In the case of handwritten digit generation (MNIST) a supervised algorithm let's you produce images of a certain number that you control (e.g. images of zeros). Supervised methods are also required for text-to-image, image-to-text, image-to-image, text-to-audio, etc. translation tasks, because output should be generated conditional on an input (what does the image look like given a specific text snippet). Currently, the encoding of the conditional vector (label, text, audio, ...) has to be handled on the user side.

An interesting middle ground is take by the InfoGAN algorithm which tries to learn the labels itself during training. We refer to the original paper for more detailed information on the algorithm, but the vegans API for this method works similar to any other GAN. A conditional version exists, called ConditionalInfoGAN where label information can be provided but additional features are learned during training.

You can currently use the following generative models:

All current generative model implementations come with a conditional variant to allow for the usage of training labels to produce specific outputs:

  • ConditionalAEE
  • ConditionalBicycleGAN
  • ConditionalEBGAN
  • ...
  • ConditionalCycleGAN
  • ConditionalPix2Pix

This can either be used to pass a one hot encoded vector to predict a specific label (generate a certain number in case of mnist: example_mnist_conditional.py or 03_mnist-conditional.ipynb) or it can also be a full image (when for example trying to rotate an image: example_image_to_image.py or 04_mnist-image-to-image.ipynb).

Models can either be passed as torch.nn.Sequential objects or by defining custom architectures, see example_input_formats.py.

Also look at the jupyter notebooks for better visualized examples and how to use the library.

Unsupervised Learning Example

from vegans.GAN import VanillaGAN
import vegans.utils.utils as utils
import vegans.utils.loading as loading

# Data preparation
datapath =  "./data/" # root path to data, if does not exist will be downloaded into this folder
X_train, y_train, X_test, y_test = loading.load_data(datapath, which="mnist", download=True)
X_train = X_train.reshape((-1, 1, 32, 32)) # required shape
X_test = X_test.reshape((-1, 1, 32, 32))
x_dim = X_train.shape[1:] # [height, width, nr_channels]
z_dim = 64

# Define your own architectures here. You can use a Sequential model or an object
# inheriting from torch.nn.Module. Here, a default model for mnist is loaded.
generator = loading.load_generator(x_dim=x_dim, z_dim=z_dim, which="example")
discriminator = loading.load_adversary(x_dim=x_dim, z_dim=z_dim, adv_type="Discriminator", which="example")

gan = VanillaGAN(
    generator=generator, adversary=discriminator,
    z_dim=z_dim, x_dim=x_dim, folder=None
)
gan.summary() # optional, shows architecture

# Training
gan.fit(X_train, enable_tensorboard=False)

# Vizualise results
images, losses = gan.get_training_results()
images = images.reshape(-1, *images.shape[2:]) # remove nr_channels for plotting
utils.plot_images(images)
utils.plot_losses(losses)

# Sample new images, you can also pass a specific noise vector
samples = gan.generate(n=36)
samples = samples.reshape(-1, *samples.shape[2:]) # remove nr_channels for plotting
utils.plot_images(samples)

Supervised / Conditional Learning Example

import torch
import numpy as np
import vegans.utils.utils as utils
import vegans.utils.loading as loading
from vegans.GAN import ConditionalVanillaGAN

# Data preparation
datapath =  "./data/" # root path to data, if does not exist will be downloaded into this folder
X_train, y_train, X_test, y_test = loading.load_data(datapath, which="mnist", download=True)
X_train = X_train.reshape((-1, 1, 32, 32)) # required shape
X_test = X_test.reshape((-1, 1, 32, 32))
nb_classes = len(set(y_train))
y_train = np.eye(nb_classes)[y_train.reshape(-1)]
y_test = np.eye(nb_classes)[y_test.reshape(-1)]

x_dim = X_train.shape[1:] # [nr_channels, height, width]
y_dim = y_train.shape[1:]
z_dim = 64

# Define your own architectures here. You can use a Sequential model or an object
# inheriting from torch.nn.Module. Here, a default model for mnist is loaded.
generator = loading.load_generator(x_dim=x_dim, z_dim=z_dim, y_dim=y_dim, which="mnist")
discriminator = loading.load_adversary(x_dim=x_dim, z_dim=z_dim, y_dim=y_dim, adv_type="Discriminator", which="mnist")

gan = ConditionalVanillaGAN(
    generator=generator, adversary=discriminator,
    z_dim=z_dim, x_dim=x_dim, y_dim=y_dim,
    folder=None, # optional
    optim={"Generator": torch.optim.RMSprop, "Adversary": torch.optim.Adam}, # optional
    optim_kwargs={"Generator": {"lr": 0.0001}, "Adversary": {"lr": 0.0001}}, # optional
    fixed_noise_size=32, # optional
    device=None, # optional
    ngpu=0 # optional

)
gan.summary() # optional, shows architecture

# Training
gan.fit(
    X_train, y_train, X_test, y_test,
    epochs=5, # optional
    batch_size=32, # optional
    steps={"Generator": 1, "Adversary": 2}, # optional, train generator once and discriminator twice on every mini-batch
    print_every="0.1e", # optional, prints progress 10 times per epoch
    					# (might also be integer input indicating number of mini-batches)
    save_model_every=None, # optional
    save_images_every=None, # optional
    save_losses_every="0.1e", # optional, save losses 10 times per epoch in internal losses dictionary used to generate
    						  # plots during and after training
    enable_tensorboard=False # optional, if true all progress is additionally saved in tensorboard subdirectory
)

# Vizualise results
images, losses = gan.get_training_results()
images = images.reshape(-1, *images.shape[2:]) # remove nr_channels for plotting
utils.plot_images(images, labels=np.argmax(gan.fixed_labels.cpu().numpy(), axis=1))
utils.plot_losses(losses)

# Generate specific label, for example "2"
label = np.array([[0, 0, 1, 0, 0, 0, 0, 0, 0,0 ]])
image = gan(y=label)
utils.plot_images(image, labels=["2"])

Slightly More Details:

Constructor arguments

All of the generative model objects inherit from a AbstractGenerativeModel base class. and allow for the following input in the constructor.

  • optim: The optimizer for all networks used during training. If None a default optimizer (probably either torch.optim.Adam or torch.optim.RMSprop) is chosen by the specific model. A dict type with appropriate keys can be passed to specify different optimizers for different networks, for example {"Generator": torch.optim.Adam}.
  • optim_kwargs: The optimizer keyword arguments. A dict type with appropriate keys can be passed to specify different optimizer keyword arguments for different networks, for example {"Generator": {"lr": 0.001}}.
  • feature_layer: If not None, it should be a layer of the discriminator or critic. The output of this layer is used to compute the mean squared error between the real and fake samples, i.e. it uses the feature loss. The existing GAN loss (often Binary cross-entropy) is overwritten.
  • fixed_noise_size: The number of samples to save (from fixed noise vectors). These are saved within tensorboard (if enable_tensorboard=True during fitting) and in the Model/images subfolder.
  • device: "cuda" (GPU) or "cpu" depending on the available resources.
  • ngpu: Number of gpus used during training
  • folder: Folder which will contain all results of the network (architecture, model.torch, images, loss plots, etc.). An existing folder will never be deleted or overwritten. If the folder already exists a new folder will be created with the given name + current time stamp.
  • secure: By default, vegans performs plenty of checks on inputs and outputs for all networks (For example encoder.output_size==z_dim, generator.output_size==x_dim or Discriminator.last_layer==torch.nn.Sigmoid). For some use cases these checks might be too restrictive. If secure=False vegans will perform only the most basic checks to run. Of course, if there are shape mismatches torch itself will still complain.

fit() arguments

The fit function takes the following optional arguments:

  • epochs: Number of epochs to train the algorithm. Default: 5
  • batch_size: Size of one batch. Default: 32
  • steps: How often one network should be trained against another. Must be dict type with appropriate names. E.g., for the WassersteinGAN the dictionary could be {"Generator": 1, "Adversary": 5}, indicating that the adversary should be trained five times on every mini-batch while the generator is trained once. The keys of the dictionary are fixed by the specified algorithm (here ["Generator", "Adversary"], for BicycleGAN would be ["Generator", "Adversary", "Encoder"] ). An appropriate error is raised if wrong keys are passed. The possible names should be obvious from the constructor of every algorithm but a wrong dictionary, e.g. {"Genrtr": 1}, can be passed consciously to receive a list of correct and available key values.
  • print_every: Determines after how many batches a message should be printed to the console informing about the current state of training. String indicating fraction or multiples of epoch can be given. I.e. "0.25e" = four times per epoch, "2e" after two epochs. Default: 100
  • save_model_every: Determines after how many batches the model should be saved. String indicating fraction or multiples of epoch can be given. I.e. "0.25e" = four times per epoch, "2e" after two epochs. Models will be saved in subdirectory folder+"/models" (folder specified in the constructor, see above in Constructor arguments). Default: None
  • save_images_every: Determines after how many batches sample images and loss curves should be saved. String indicating fraction or multiples of epoch can be given. I.e. "0.25e" = four times per epoch, "2e" after two epochs. Images will be saved in subdirectory folder+"/images" (folder specified in the constructor, see above in Constructor arguments). Default: None
  • save_losses_every: Determines after how many batches the losses should be calculated and saved. Figure is shown after save_images_every . String indicating fraction or multiples of epoch can be given. I.e. "0.25e" = four times per epoch, "2e" after two epochs. Default: "1e"
  • enable_tensorboard: Tensorboard information for losses, samples and training time will be saved in subdirectory folder+"/tensorboard" (folder specified in the constructor, see above in Constructor arguments). Default: False

All of the generative model objects inherit from a AbstractGenerativeModel base class. When building any such GAN, you must pass generator / decoder as well as discriminator / encoder networks (some torch.nn.Module), as well as a the dimensions of the latent space z_dim and input dimension of the images x_dim.

Generative Model methods:

  • generate(z=None, n=None) / generate(y, z=None, n=None): Generate samples from noise vector or generate "n" samples.

  • get_hyperparameters(): Get dictionary containing important hyperparameters.

  • get_losses(by_epoch=False, agg=None): Return a dictionary of logged losses. Number of elements determined by the save_losses_every parameter passed to the fit method.

  • get_number_params(): Get the number of parameters per network.

  • get_training_results(by_epoch=False, agg=None): Returns the samples generated from the fixed_noise attribute and the logged losses.

  • load(path): Load a trained model.

  • predict(x): Use the adversary to predict the realness of an image.

  • sample(n): Sample a noise vector of size n.

  • save(name=None): Save the model.

  • summary(save=False): Print a summary of the model containing the number of parameters and general structure.

  • to(device): Map all networks to a common device. Should be done before training.

Generative model attributes:

  • feature_layer: Function to calculate feature loss with. If None no feature loss is computed. If not None the feature loss overwrites the "normal" generator loss.
  • fixed_noise, (fixed_noise_labels): Noise vector sampled before training and used to generate the images in the created subdirectory (if save_images_every in the fit mehtod is not None). Also used to produce the results from get_training_results().
  • folder: Folder where all information belonging to GAN is stored. This includes
    • Models in the folder/models subdirectory if save_model_every is not None in thefit() method.
    • Images in the folder/images subdirectory if save_images_every is not None in the fit() method.
    • Tensorboard data in the folder/tensorboard subdirectory if enable_tensorboard is True in the fit() method.
    • Loss in the folder/losses.png if save_losses_every is not None in fit() method.
    • Loss in the folder/summary.txt if summary(save=True)called.
  • images_produced: Flag (True / False) if images are the target of the generator.
  • total_training_time, batch_training_times: Time needed for training.
  • x_dim, z_dim, (y_dim): Input dimensions.
  • training: Flag (True / False) if model is in training or evaluation mode. Normally the flag is False and is automatically set to True in the main training loop.

Attentive readers might notice that in most places we try to talk about "Generative Models" instead of "Generative Adversarial Networks", because vegans currently also supports the Variational Autoencoder algorithm (VanillaVAE) which are their own method of generating data. However, you can interpret the decoder of the VAE equivalently to a generator of a GAN. Both take the latent space (and sometimes labels) as input and transform them to the desired output space. In an abstract sense the encoder of the VAE also corresponds to the discriminator of the GAN as both aim to condense their input from the image space to a lower dimensional latent dimension. These abstract commonalities are used in the AbstractGenerativeModel to unify both types of algorithms and provide a largely similar API.

In the future we also plan to implement different VAE algorithms to have all generative models in one place but for now the library is focused on GAN algorithms.

If you are researching new generative model training algorithms, you may find it useful to inherit from the AbstractGenerativeModel or AbstractConditionalGenerativeModel base class.

Learn more:

Currently the best way to learn more about how to use vegans is to have a look at the example notebooks. You can start with this simple example showing how to sample from a univariate Gaussian using a GAN. Alternatively, can run example scripts.

Contribute

PRs and suggestions are welcome. Look here for more details on the setup.

Credits

Some of the code has been inspired by some existing GAN implementations:

Some Results

All this results should be taken with a grain of salt. They were not extensively fine tuned in any way, so better results for individual networks are possible for sure. More time training as well as more regularization could most certainly improve results. All of these results were generated by running the example_conditional.py program in the examples folder. Especially the Variational Autoencoder would perform better if we increased it's number of parameters to a comparable level.

Network MNIST Result
Cond. BicycleGAN MNIST
Cond. EBGAN MNIST
Cond. InfoGAN MNIST
Cond. KLGAN MNIST
Cond. LRGAN MNIST
Cond. Pix2Pix MNIST
Cond. VAEGAN MNIST
Cond. VanillaGAN MNIST
Cond. WassersteinGAN MNIST
Cond. WassersteinGANGP MNIST
Cond. VAE MNIST

TODO

  • GAN Implementations (sorted by priority)

    • BEGAN
    • WassersteinGAN SpectralNorm
    • Stacked GAN here
    • Progressive Growing GAN here
  • Layers

    • Minibatch discrimination
    • Instance normalization
  • Other

    • Core Improvements:

      • Hide feature_layer, secure in **kwargs
      • Make it more PEP conform
      • Make _default_optimizer not abstract
      • Windows installation issues
      • CI workflow
      • Create protected branches
      • Conda installation
      • Type annotations
      • Documentation website
      • build fancy examples
    • Perceptual Loss here

    • Interpolation

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

vegans-0.3.0.tar.gz (13.1 MB view details)

Uploaded Source

Built Distribution

vegans-0.3.0-py3-none-any.whl (116.7 kB view details)

Uploaded Python 3

File details

Details for the file vegans-0.3.0.tar.gz.

File metadata

  • Download URL: vegans-0.3.0.tar.gz
  • Upload date:
  • Size: 13.1 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.9.2

File hashes

Hashes for vegans-0.3.0.tar.gz
Algorithm Hash digest
SHA256 5d8bc953b9775f5a3c6d683c1c99cabea85cf7c10d847b5ec8b4840991b0142b
MD5 d5df00aab7dcfb1e245c2c7266d67ede
BLAKE2b-256 a99d2051704956722ddb922260250ef6c44f41c077a1a994d8e40e7d4b6af69a

See more details on using hashes here.

File details

Details for the file vegans-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: vegans-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 116.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.9.2

File hashes

Hashes for vegans-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 0eeffa2840865b705db44e3be235963198e99193b175893ae1760fdb4a7e59ba
MD5 cfd1d0f7451e24d0ae631989ceef1ef9
BLAKE2b-256 0cb4df04ea4fb0c79c3e878db2321722d39bed087c085034b2958bda021a2381

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page