vegans
A library to easily train various existing GANs (and other generative models) in PyTorch.
This library mainly targets GAN users who want to use existing GAN training techniques with their own generators/discriminators. However, researchers may also find the GenerativeModel base class useful for quicker implementation of new GAN training techniques.
The focus is on simplicity and providing reasonable defaults.
How to install
You need Python 3.7 or above. Then:
```
pip install vegans
```
How to use
The basic idea is that the user provides discriminator / critic and generator networks (additionally an encoder if needed), and the library takes care of training them in a selected GAN setting. To get familiar with the library:
- Read through this README.md file
- Check out the notebooks (00 to 04)
- If you want to create your own GAN algorithms, check out the notebooks 05 to 07
- Look at the example code snippets
vegans implements two types of generative models: unsupervised and supervised (examples of both are given below). Unsupervised algorithms are used when no labels exist for the data you want to generate, for example when it is too tedious or infeasible to label every sample. The disadvantage is that after training the generation process is unsupervised as well, meaning you have (in most cases) little control over which type of output is generated. Supervised algorithms, on the other hand, require you to specify the input dimension of the label (`y_dim`) and to provide labels during training. All algorithms requiring labels are implemented as "Conditional" GANs (e.g. `VanillaGAN` does not take labels, whereas `ConditionalVanillaGAN` does). These algorithms enable you to generate a specific output conditional on a certain input.

In the case of handwritten digit generation (MNIST), a supervised algorithm lets you produce images of a digit that you control (e.g. images of zeros). Supervised methods are also required for text-to-image, image-to-text, image-to-image, text-to-audio, etc. translation tasks, because the output should be generated conditional on an input (what does the image look like given a specific text snippet?). Currently, the encoding of the conditional vector (label, text, audio, ...) has to be handled on the user side.
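As a minimal sketch of this user-side encoding, integer MNIST class labels could be one-hot encoded into the `y_dim = (10,)` vector that the conditional models expect (the same pattern appears in the supervised example further below):

```python
import numpy as np

labels = np.array([3, 1, 4])    # raw integer class labels
nb_classes = 10                 # ten digits for MNIST
y = np.eye(nb_classes)[labels]  # shape (3, 10): one one-hot row per label
```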
An interesting middle ground is taken by the `InfoGAN` algorithm, which tries to learn the labels itself during training. We refer to the original paper for more detailed information on the algorithm, but the vegans API for this method works similarly to that of any other GAN. A conditional version exists as well, called `ConditionalInfoGAN`, where label information can be provided and additional features are learned during training.
You can currently use the following generative models:
- `AAE`: Adversarial Auto-Encoder
- `BicycleGAN`: BicycleGAN
- `EBGAN`: Energy-Based GAN
- `InfoGAN`: Information-Based GAN
- `KLGAN`: Kullback-Leibler GAN
- `LRGAN`: Latent-Regressor GAN
- `LSGAN`: Least-Squares GAN
- `VAEGAN`: Variational Auto-Encoder GAN
- `VanillaGAN`: Classic minimax GAN, in its non-saturated version
- `VanillaVAE`: Variational Auto-Encoder
- `WassersteinGAN`: Wasserstein GAN
- `WassersteinGANGP`: Wasserstein GAN with gradient penalty
All current generative model implementations come with a conditional variant to allow for the usage of training labels to produce specific outputs:
- `ConditionalAAE`
- `ConditionalBicycleGAN`
- `ConditionalEBGAN`
- ...
- `ConditionalCycleGAN`
- `ConditionalPix2Pix`
The condition can either be a one-hot encoded vector used to generate a specific label (a certain digit in the case of MNIST: example_mnist_conditional.py or 03_mnist-conditional.ipynb), or a full image (when, for example, trying to rotate an image: example_image_to_image.py or 04_mnist-image-to-image.ipynb).
Models can be passed either as `torch.nn.Sequential` objects or as custom architectures, see example_input_formats.py and the sketch below.
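A minimal sketch of the two accepted formats (the layer sizes are illustrative assumptions and must match your `z_dim` / `x_dim`):

```python
import torch.nn as nn

# 1) A plain Sequential model.
generator = nn.Sequential(
    nn.Linear(64, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 32 * 32),
    nn.Sigmoid(),
)

# 2) A custom architecture inheriting from torch.nn.Module.
class MyGenerator(nn.Module):
    def __init__(self, z_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 32 * 32),
            nn.Sigmoid(),
        )

    def forward(self, z):
        return self.net(z)
```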
Also look at the Jupyter notebooks for better visualized examples of how to use the library.
Unsupervised Learning Example
```python
from vegans.GAN import VanillaGAN
import vegans.utils.utils as utils
import vegans.utils.loading as loading

# Data preparation
datapath = "./data/"  # root path to the data; if it does not exist, the data is downloaded into this folder
X_train, y_train, X_test, y_test = loading.load_data(datapath, which="mnist", download=True)
X_train = X_train.reshape((-1, 1, 32, 32))  # required shape
X_test = X_test.reshape((-1, 1, 32, 32))
x_dim = X_train.shape[1:]  # [nr_channels, height, width]
z_dim = 64

# Define your own architectures here. You can use a Sequential model or an object
# inheriting from torch.nn.Module. Here, a default model for mnist is loaded.
generator = loading.load_generator(x_dim=x_dim, z_dim=z_dim, which="example")
discriminator = loading.load_adversary(x_dim=x_dim, z_dim=z_dim, adv_type="Discriminator", which="example")

gan = VanillaGAN(
    generator=generator, adversary=discriminator,
    z_dim=z_dim, x_dim=x_dim, folder=None
)
gan.summary()  # optional, shows architecture

# Training
gan.fit(X_train, enable_tensorboard=False)

# Visualize results
images, losses = gan.get_training_results()
images = images.reshape(-1, *images.shape[2:])  # remove nr_channels for plotting
utils.plot_images(images)
utils.plot_losses(losses)

# Sample new images; you can also pass a specific noise vector
samples = gan.generate(n=36)
samples = samples.reshape(-1, *samples.shape[2:])  # remove nr_channels for plotting
utils.plot_images(samples)
```
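As the last comment above notes, `generate` also accepts a specific noise vector instead of a sample count; a small sketch combining the `sample` and `generate` methods documented further below:

```python
# Generate from a fixed noise vector instead of sampling n fresh ones.
my_noise = gan.sample(n=4)          # or any tensor of shape (4, z_dim)
samples = gan.generate(z=my_noise)  # same noise yields the same samples
```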
Supervised / Conditional Learning Example
```python
import torch
import numpy as np
import vegans.utils.utils as utils
import vegans.utils.loading as loading
from vegans.GAN import ConditionalVanillaGAN

# Data preparation
datapath = "./data/"  # root path to the data; if it does not exist, the data is downloaded into this folder
X_train, y_train, X_test, y_test = loading.load_data(datapath, which="mnist", download=True)
X_train = X_train.reshape((-1, 1, 32, 32))  # required shape
X_test = X_test.reshape((-1, 1, 32, 32))
nb_classes = len(set(y_train))
y_train = np.eye(nb_classes)[y_train.reshape(-1)]  # one-hot encode the labels
y_test = np.eye(nb_classes)[y_test.reshape(-1)]
x_dim = X_train.shape[1:]  # [nr_channels, height, width]
y_dim = y_train.shape[1:]
z_dim = 64

# Define your own architectures here. You can use a Sequential model or an object
# inheriting from torch.nn.Module. Here, a default model for mnist is loaded.
generator = loading.load_generator(x_dim=x_dim, z_dim=z_dim, y_dim=y_dim, which="mnist")
discriminator = loading.load_adversary(x_dim=x_dim, z_dim=z_dim, y_dim=y_dim, adv_type="Discriminator", which="mnist")

gan = ConditionalVanillaGAN(
    generator=generator, adversary=discriminator,
    z_dim=z_dim, x_dim=x_dim, y_dim=y_dim,
    folder=None,  # optional
    optim={"Generator": torch.optim.RMSprop, "Adversary": torch.optim.Adam},  # optional
    optim_kwargs={"Generator": {"lr": 0.0001}, "Adversary": {"lr": 0.0001}},  # optional
    fixed_noise_size=32,  # optional
    device=None,  # optional
    ngpu=0  # optional
)
gan.summary()  # optional, shows architecture

# Training
gan.fit(
    X_train, y_train, X_test, y_test,
    epochs=5,  # optional
    batch_size=32,  # optional
    steps={"Generator": 1, "Adversary": 2},  # optional, train the generator once and the discriminator twice per mini-batch
    print_every="0.1e",  # optional, prints progress 10 times per epoch (can also be an
                         # integer indicating a number of mini-batches)
    save_model_every=None,  # optional
    save_images_every=None,  # optional
    save_losses_every="0.1e",  # optional, save losses 10 times per epoch in the internal losses
                               # dictionary used to generate plots during and after training
    enable_tensorboard=False  # optional, if True all progress is additionally saved in a tensorboard subdirectory
)

# Visualize results
images, losses = gan.get_training_results()
images = images.reshape(-1, *images.shape[2:])  # remove nr_channels for plotting
utils.plot_images(images, labels=np.argmax(gan.fixed_labels.cpu().numpy(), axis=1))
utils.plot_losses(losses)

# Generate a specific label, for example "2"
label = np.array([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]])
image = gan(y=label)
utils.plot_images(image, labels=["2"])
```
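The same call pattern extends to a batch of labels; a short sketch (reusing `gan` and `utils` from the example above) generating one image per digit:

```python
# One one-hot row per digit 0-9 yields one generated image per row.
all_labels = np.eye(10)
images = gan(y=all_labels)
images = images.reshape(-1, *images.shape[2:])  # remove nr_channels for plotting
utils.plot_images(images, labels=list(range(10)))
```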
Slightly More Details:
Constructor arguments
All of the generative model objects inherit from the `AbstractGenerativeModel` base class and allow for the following inputs in the constructor:

- `optim`: The optimizer for all networks used during training. If `None`, a default optimizer (usually either `torch.optim.Adam` or `torch.optim.RMSprop`) is chosen by the specific model. A `dict` with appropriate keys can be passed to specify different optimizers for different networks, for example `{"Generator": torch.optim.Adam}`.
- `optim_kwargs`: The optimizer keyword arguments. A `dict` with appropriate keys can be passed to specify different optimizer keyword arguments for different networks, for example `{"Generator": {"lr": 0.001}}`.
- `feature_layer`: If not `None`, it should be a layer of the discriminator or critic. The output of this layer is used to compute the mean squared error between the real and fake samples, i.e. the feature loss. The existing GAN loss (often binary cross-entropy) is overwritten (see the sketch after this list).
- `fixed_noise_size`: The number of samples to save (generated from fixed noise vectors). These are saved within tensorboard (if `enable_tensorboard=True` during fitting) and in the `Model/images` subfolder.
- `device`: "cuda" (GPU) or "cpu", depending on the available resources.
- `ngpu`: The number of GPUs used during training.
- `folder`: The folder which will contain all results of the network (architecture, model.torch, images, loss plots, etc.). An existing folder will never be deleted or overwritten; if the folder already exists, a new folder is created with the given name plus the current time stamp.
- `secure`: By default, vegans performs plenty of checks on the inputs and outputs of all networks (for example `encoder.output_size == z_dim`, `generator.output_size == x_dim` or `Discriminator.last_layer == torch.nn.Sigmoid`). For some use cases these checks might be too restrictive. With `secure=False`, vegans performs only the most basic checks required to run. Of course, if there are shape mismatches, torch itself will still complain.
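To illustrate `feature_layer`, a hedged sketch: the attribute path `hidden_part[2]` is an assumption about how a particular discriminator might be structured; pick any inner layer of your own network.

```python
# Assumption: the discriminator exposes an indexable inner nn.Sequential.
feature_layer = discriminator.hidden_part[2]

gan = VanillaGAN(
    generator=generator, adversary=discriminator,
    z_dim=z_dim, x_dim=x_dim,
    feature_layer=feature_layer,  # generator now trains on the MSE feature loss
    folder=None,
)
```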
fit() arguments
The fit function takes the following optional arguments:
- `epochs`: Number of epochs to train the algorithm. Default: 5
- `batch_size`: Size of one batch. Default: 32
- `steps`: How often one network should be trained relative to another. Must be a `dict` with appropriate names. E.g., for the `WassersteinGAN` the dictionary could be `{"Generator": 1, "Adversary": 5}`, indicating that the adversary is trained five times on every mini-batch while the generator is trained once. The keys of the dictionary are fixed by the specified algorithm (here `["Generator", "Adversary"]`; for the `BicycleGAN` it would be `["Generator", "Adversary", "Encoder"]`). An appropriate error is raised if wrong keys are passed. The possible names should be obvious from the constructor of every algorithm, but a wrong dictionary, e.g. `{"Genrtr": 1}`, can be passed consciously to receive a list of the correct and available keys. See the sketch after this list.
- `print_every`: Determines after how many batches a message informing about the current state of training is printed to the console. A string indicating a fraction or multiple of an epoch can be given, e.g. "0.25e" = four times per epoch, "2e" = after every two epochs. Default: 100
- `save_model_every`: Determines after how many batches the model should be saved. A string indicating a fraction or multiple of an epoch can be given, e.g. "0.25e" = four times per epoch, "2e" = after every two epochs. Models are saved in the `folder`+"/models" subdirectory (`folder` specified in the constructor, see Constructor arguments above). Default: None
- `save_images_every`: Determines after how many batches sample images and loss curves should be saved. Same string format as above. Images are saved in the `folder`+"/images" subdirectory. Default: None
- `save_losses_every`: Determines after how many batches the losses should be calculated and saved. The figure is shown after `save_images_every`. Same string format as above. Default: "1e"
- `enable_tensorboard`: Tensorboard information for losses, samples and training time is saved in the `folder`+"/tensorboard" subdirectory. Default: False
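For instance, a sketch of these arguments for a `WassersteinGAN` (here called `wgan`, built like the examples above), where the critic is typically trained more often than the generator:

```python
wgan.fit(
    X_train,
    epochs=10,
    steps={"Generator": 1, "Adversary": 5},  # train the critic 5x per mini-batch
    print_every="0.25e",                     # progress message 4 times per epoch
    save_losses_every=10,                    # log losses every 10 mini-batches
)
```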
As mentioned above, all generative model objects inherit from the `AbstractGenerativeModel` base class. When building any such model, you must pass the generator / decoder as well as the discriminator / encoder networks (some `torch.nn.Module`), together with the dimension of the latent space `z_dim` and the input dimension of the images `x_dim`.
Generative Model methods:
- `generate(z=None, n=None)` / `generate(y, z=None, n=None)`: Generate samples from a noise vector or generate `n` samples.
- `get_hyperparameters()`: Get a dictionary containing the important hyperparameters.
- `get_losses(by_epoch=False, agg=None)`: Return a dictionary of logged losses. The number of elements is determined by the `save_losses_every` parameter passed to the `fit` method.
- `get_number_params()`: Get the number of parameters per network.
- `get_training_results(by_epoch=False, agg=None)`: Return the samples generated from the `fixed_noise` attribute and the logged losses.
- `load(path)`: Load a trained model.
- `predict(x)`: Use the adversary to predict the realness of an image.
- `sample(n)`: Sample a noise vector of size `n`.
- `save(name=None)`: Save the model (see the round-trip sketch after this list).
- `summary(save=False)`: Print a summary of the model containing the number of parameters and the general structure.
- `to(device)`: Map all networks to a common device. Should be done before training.
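A short sketch of a save / load round trip using these methods (whether `load` is called on the class, as assumed here, may vary between versions):

```python
gan.save("vanilla_gan.torch")                    # persist the trained model
restored = VanillaGAN.load("vanilla_gan.torch")  # assumption: callable on the class
restored.to("cpu")                               # map all networks to one device
samples = restored.generate(n=16)                # sample from the restored model
```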
Generative model attributes:
- `feature_layer`: Function used to calculate the feature loss. If None, no feature loss is computed. If not None, the feature loss overwrites the "normal" generator loss.
- `fixed_noise` (`fixed_noise_labels`): Noise vector sampled before training and used to generate the images in the created subdirectory (if `save_images_every` in the `fit` method is not None). Also used to produce the results from `get_training_results()`.
- `folder`: Folder where all information belonging to the GAN is stored. This includes:
  - Models in the `folder/models` subdirectory if `save_model_every` is not None in the `fit()` method.
  - Images in the `folder/images` subdirectory if `save_images_every` is not None in the `fit()` method.
  - Tensorboard data in the `folder/tensorboard` subdirectory if `enable_tensorboard` is True in the `fit()` method.
  - Losses in `folder/losses.png` if `save_losses_every` is not None in the `fit()` method.
  - A summary in `folder/summary.txt` if `summary(save=True)` is called.
- `images_produced`: Flag (True / False) indicating whether images are the target of the generator.
- `total_training_time`, `batch_training_times`: Time needed for training.
- `x_dim`, `z_dim`, (`y_dim`): Input dimensions.
- `training`: Flag (True / False) indicating whether the model is in training or evaluation mode. Normally the flag is False; it is automatically set to True in the main training loop.
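For example, a few of these attributes and the related methods in use after a `fit()` call like the ones above:

```python
print(gan.total_training_time)          # seconds spent in fit()
print(gan.x_dim, gan.z_dim, gan.y_dim)  # input dimensions (y_dim: conditional models only)
losses = gan.get_losses()               # dict of losses logged every save_losses_every batches
```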
Attentive readers might notice that in most places we talk about "generative models" instead of "generative adversarial networks", because vegans currently also supports the Variational Autoencoder algorithm (`VanillaVAE`), which is its own method of generating data. However, you can interpret the decoder of the VAE as the equivalent of a GAN's generator: both take the latent space (and sometimes labels) as input and transform it to the desired output space. In an abstract sense the encoder of the VAE also corresponds to the discriminator of the GAN, as both aim to condense their input from the image space to a lower-dimensional latent space. These abstract commonalities are used in the `AbstractGenerativeModel` to unify both types of algorithms and provide a largely similar API.
In the future we also plan to implement different VAE algorithms to have all generative models in one place, but for now the library is focused on GAN algorithms.
If you are researching new generative model training algorithms, you may find it useful to inherit from the `AbstractGenerativeModel` or `AbstractConditionalGenerativeModel` base class.
Learn more:
Currently the best way to learn more about how to use vegans is to have a look at the example notebooks. You can start with the simple example showing how to sample from a univariate Gaussian using a GAN. Alternatively, you can run the example scripts.
Contribute
PRs and suggestions are welcome. Look here for more details on the setup.
Credits
Some of the code was inspired by existing GAN implementations:
- https://github.com/eriklindernoren/PyTorch-GAN
- https://github.com/martinarjovsky/WassersteinGAN
- https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
Some Results
All of these results should be taken with a grain of salt. They were not extensively fine-tuned in any way, so better results for the individual networks are certainly possible. More training time as well as more regularization could almost certainly improve the results. All of these results were generated by running the example_conditional.py program in the examples folder. The Variational Autoencoder in particular would perform better if its number of parameters were increased to a comparable level.
MNIST result images (not reproduced here) were generated for the following networks: Cond. BicycleGAN, Cond. EBGAN, Cond. InfoGAN, Cond. KLGAN, Cond. LRGAN, Cond. Pix2Pix, Cond. VAEGAN, Cond. VanillaGAN, Cond. WassersteinGAN, Cond. WassersteinGANGP and Cond. VAE.
TODO
- GAN Implementations (sorted by priority)
- Layers
  - Minibatch discrimination
  - Instance normalization
- Other
  - New links to correct github files
  - Core Improvements:
    - Hide feature_layer, secure in **kwargs
    - Generalize conditional networks and only let them handle correct concatenation
    - Abstraction for GDE networks (1v1+1)
    - Make it more PEP conform
    - Make _default_optimizer not abstract
  - Perceptual Loss here
  - Interpolation