Unifying Generative Autoencoders in Python

pythae

This library implements some of the most common (Variational) Autoencoder models. In particular, it lets you run benchmark experiments and comparisons by training all the models with the same autoencoding neural network architecture. The "make your own autoencoder" feature allows you to train any of these models with your own data and your own encoder and decoder neural networks.

Installation

To install the latest version of this library, run the following with pip:

$ pip install git+https://github.com/clementchadebec/benchmark_VAE.git

Alternatively, you can clone the GitHub repository to access the tests, tutorials and scripts:

$ git clone https://github.com/clementchadebec/benchmark_VAE.git

Then install the library:

$ cd benchmark_VAE
$ pip install -e .

Available Models

Below is the list of the models currently implemented in the library.

Models Training example Paper Official Implementation
Autoencoder (AE) Open In Colab
Variational Autoencoder (VAE) Open In Colab link
Beta Variational Autoencoder (BetaVAE) Open In Colab link
VAE with Linear Normalizing Flows (VAE_LinNF) Open In Colab link
VAE with Inverse Autoregressive Flows (VAE_IAF) Open In Colab link link
Disentangled Beta Variational Autoencoder (DisentangledBetaVAE) Open In Colab link
Disentangling by Factorising (FactorVAE) Open In Colab link
Beta-TC-VAE (BetaTCVAE) Open In Colab link link
Importance Weighted Autoencoder (IWAE) Open In Colab link link
VAE with perceptual metric similarity (MSSSIM_VAE) Open In Colab link
Wasserstein Autoencoder (WAE) Open In Colab link link
Info Variational Autoencoder (INFOVAE_MMD) Open In Colab link
VAMP Autoencoder (VAMP) Open In Colab link link
Hyperspherical VAE (SVAE) Open In Colab link link
Adversarial Autoencoder (Adversarial_AE) Open In Colab link
Variational Autoencoder GAN (VAEGAN) 🥗 Open In Colab link link
Vector Quantized VAE (VQVAE) Open In Colab link link
Hamiltonian VAE (HVAE) Open In Colab link link
Regularized AE with L2 decoder param (RAE_L2) Open In Colab link link
Regularized AE with gradient penalty (RAE_GP) Open In Colab link link
Riemannian Hamiltonian VAE (RHVAE) Open In Colab link link

See the Results section below for the reconstruction and generation results of all the aforementioned models.

Available Samplers

Below is the list of the samplers currently implemented in the library.

Samplers Models Paper Official Implementation
Normal prior (NormalSampler) all models link
Gaussian mixture (GaussianMixtureSampler) all models link link
Two stage VAE sampler (TwoStageVAESampler) all VAE based models link link
Unit sphere uniform sampler (HypersphereUniformSampler) SVAE link link
VAMP prior sampler (VAMPSampler) VAMP link link
Manifold sampler (RHVAESampler) RHVAE link link
Masked Autoregressive Flow Sampler (MAFSampler) all models link link
Inverse Autoregressive Flow Sampler (IAFSampler) all models link link
PixelCNN (PixelCNNSampler) VQVAE link

Launching a model training

To train a model, you only need to build a TrainingPipeline instance and call it.

>>> from pythae.pipelines import TrainingPipeline
>>> from pythae.models import VAE, VAEConfig
>>> from pythae.trainers import BaseTrainerConfig

>>> # Set up the training configuration
>>> my_training_config = BaseTrainerConfig(
...	output_dir='my_model',
...	num_epochs=50,
...	learning_rate=1e-3,
...	batch_size=200,
...	steps_saving=None
... )
>>> # Set up the model configuration 
>>> my_vae_config = VAEConfig(
...	input_dim=(1, 28, 28),
...	latent_dim=10
... )
>>> # Build the model
>>> my_vae_model = VAE(
...	model_config=my_vae_config
... )
>>> # Build the Pipeline
>>> pipeline = TrainingPipeline(
... 	training_config=my_training_config,
... 	model=my_vae_model
... )
>>> # Launch the Pipeline
>>> pipeline(
...	train_data=your_train_data, # must be a torch.Tensor or np.ndarray
...	eval_data=your_eval_data # must be a torch.Tensor or np.ndarray
... )
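Each training step minimizes the model's loss on mini-batches of the training data. As an illustration only, the hypothetical function below computes a typical per-sample VAE objective in plain Python: a squared-error reconstruction term plus the closed-form KL divergence between the diagonal Gaussian posterior N(mu, diag(exp(log_var))) and the N(0, I) prior, with an optional beta weight in the spirit of the BetaVAE. It is a sketch, not pythae's implementation, and the reconstruction term pythae actually uses may differ.

```python
import math


def vae_loss(x, x_rec, mu, log_var, beta=1.0):
    """Hypothetical per-sample VAE loss: reconstruction + beta * KL.

    x, x_rec: flattened input and reconstruction (lists of floats)
    mu, log_var: encoder mean and diagonal log-variance (lists of floats)
    beta=1.0 gives the standard VAE objective; beta > 1 gives a BetaVAE-style weighting.
    """
    # Sum of squared errors as the reconstruction term (an assumed choice)
    recon = sum((a - b) ** 2 for a, b in zip(x, x_rec))
    # Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) )
    kl = -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mu, log_var))
    return recon + beta * kl


# A perfect reconstruction with a posterior equal to the prior gives a zero loss
loss = vae_loss([0.2, 0.8], [0.2, 0.8], mu=[0.0, 0.0], log_var=[0.0, 0.0])
```

Raising beta trades reconstruction quality for a latent distribution closer to the prior.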

At the end of training, the best model weights, the model configuration and the training configuration are stored in a final_model folder inside my_model/MODEL_NAME_training_YYYY-MM-DD_hh-mm-ss (my_model being the output_dir argument of the BaseTrainerConfig). If you also set the steps_saving argument to an integer value, folders named checkpoint_epoch_k containing the best model weights, optimizer, scheduler, model configuration and training configuration at epoch k will also appear in my_model/MODEL_NAME_training_YYYY-MM-DD_hh-mm-ss.
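When several runs share the same output_dir, each one gets its own time-stamped folder. The helper below is a hypothetical sketch (not part of pythae) that picks the most recent final_model folder by parsing the YYYY-MM-DD_hh-mm-ss suffix described above; it assumes the folder names follow exactly that pattern.

```python
import os
from datetime import datetime


def latest_final_model(output_dir, model_name):
    """Return the most recent <model_name>_training_<timestamp>/final_model path (hypothetical helper)."""
    prefix = f"{model_name}_training_"
    candidates = []
    for entry in os.listdir(output_dir):
        if not entry.startswith(prefix):
            continue
        try:
            # Assumed suffix format: YYYY-MM-DD_hh-mm-ss
            stamp = datetime.strptime(entry[len(prefix):], "%Y-%m-%d_%H-%M-%S")
        except ValueError:
            continue  # skip folders that do not follow the naming scheme
        final_dir = os.path.join(output_dir, entry, "final_model")
        if os.path.isdir(final_dir):  # ignore runs that did not finish
            candidates.append((stamp, final_dir))
    if not candidates:
        raise FileNotFoundError(f"no final_model folder for {model_name} in {output_dir}")
    # max() on (timestamp, path) tuples selects the latest run
    return max(candidates)[1]
```

The returned path could then be passed to AutoModel.load_from_folder, assuming MODEL_NAME matches the name you pass in (for example "VAE").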

Launching a training on benchmark datasets

We also provide an example training script (in the scripts folder) that can be used to train the models on benchmark datasets (mnist, cifar10, celeba ...). The script can be launched with the following command line:

python training.py --dataset mnist --model_name ae --model_config 'configs/ae_config.json' --training_config 'configs/base_training_config.json'

See README.md for further details on this script.

Launching data generation

Using the GenerationPipeline

The easiest way to generate data from a trained model is to use the built-in GenerationPipeline provided in Pythae. Say you want to generate 100 samples using a MAFSampler: all you have to do is 1) reload the trained model, 2) define the sampler's configuration and 3) create and launch the GenerationPipeline as follows:

>>> from pythae.models import AutoModel
>>> from pythae.samplers import MAFSamplerConfig
>>> from pythae.pipelines import GenerationPipeline
>>> from pythae.trainers import BaseTrainerConfig
>>> # Retrieve the trained model
>>> my_trained_vae = AutoModel.load_from_folder(
...	'path/to/your/trained/model'
... )
>>> my_sampler_config = MAFSamplerConfig(
...	n_made_blocks=2,
...	n_hidden_in_made=3,
...	hidden_size=128
... )
>>> # Build the pipeline
>>> pipe = GenerationPipeline(
...	model=my_trained_vae,
...	sampler_config=my_sampler_config
... )
>>> # Launch data generation
>>> generated_samples = pipe(
...	num_samples=100,
...	return_gen=True, # If False, returns nothing
...	train_data=train_data, # Needed to fit the sampler
...	eval_data=eval_data, # Needed to fit the sampler
...	training_config=BaseTrainerConfig(num_epochs=200) # Training configuration used to fit the sampler
... )
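The sampler is fitted on the data passed to the pipeline, as the comments above note. Assuming the MAFSampler follows the standard Masked Autoregressive Flow construction (a stack of MADE blocks, which the n_made_blocks parameter suggests), the key ingredient is a set of binary masks ensuring that each output dimension depends only on the preceding latent dimensions. The pure-Python sketch below builds such masks for one hidden layer of hidden_size units and checks the autoregressive property; the helper names are hypothetical, hidden degrees are assigned deterministically for readability, and this is not pythae's code.

```python
def made_masks(latent_dim, hidden_size):
    """Build MADE masks for a single hidden layer (hypothetical helper).

    Input and output unit i get degree i + 1 (1..latent_dim); hidden units get
    degrees cycling through 1..latent_dim - 1.
    """
    if latent_dim < 2:
        raise ValueError("MADE masks need latent_dim >= 2")
    in_deg = [i + 1 for i in range(latent_dim)]
    hid_deg = [(k % (latent_dim - 1)) + 1 for k in range(hidden_size)]
    # Hidden unit k may see input i only if deg(k) >= deg(i)
    mask_in = [[1 if hid_deg[k] >= in_deg[i] else 0 for i in range(latent_dim)]
               for k in range(hidden_size)]
    # Output unit d may see hidden unit k only if deg(d) > deg(k)
    mask_out = [[1 if in_deg[d] > hid_deg[k] else 0 for k in range(hidden_size)]
                for d in range(latent_dim)]
    return mask_in, mask_out


def connectivity(mask_in, mask_out):
    """Entry [d][i] > 0 iff output d can depend on input i through the hidden layer."""
    return [[sum(mask_out[d][k] * mask_in[k][i] for k in range(len(mask_in)))
             for i in range(len(mask_in[0]))]
            for d in range(len(mask_out))]


m_in, m_out = made_masks(latent_dim=4, hidden_size=8)
conn = connectivity(m_in, m_out)
# Autoregressive check: output d never depends on inputs i >= d
assert all(conn[d][i] == 0 for d in range(4) for i in range(d, 4))
```

Stacking several such blocks (with the dimension order changed between them) is what makes the flow expressive while keeping its density tractable.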

Using the Samplers

Alternatively, you can generate data from a trained model by using a sampler directly. For instance, to generate new data with a NormalSampler, run the following:

>>> from pythae.models import AutoModel
>>> from pythae.samplers import NormalSampler
>>> # Retrieve the trained model
>>> my_trained_vae = AutoModel.load_from_folder(
...	'path/to/your/trained/model'
... )
>>> # Define your sampler
>>> my_sampler = NormalSampler(
...	model=my_trained_vae
... )
>>> # Generate samples
>>> gen_data = my_sampler.sample(
...	num_samples=50,
...	batch_size=10,
...	output_dir=None,
...	return_gen=True
... )
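Conceptually, a NormalSampler draws latent codes from the standard normal prior N(0, I) and passes them through the decoder, processing num_samples codes in chunks of batch_size. The pure-Python sketch below mimics that loop under stated assumptions: decode is a hypothetical stand-in for the trained decoder, output_names reproduces the zero-padded file naming used when output_dir is set, and neither function is part of pythae.

```python
import random


def sample_prior(decode, latent_dim, num_samples, batch_size, seed=None):
    """Hypothetical NormalSampler-style loop: z ~ N(0, I), then decode batch by batch."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    rng = random.Random(seed)
    generated = []
    remaining = num_samples
    while remaining > 0:
        n = min(batch_size, remaining)  # the last batch may be smaller
        batch_z = [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(n)]
        generated.extend(decode(z) for z in batch_z)
        remaining -= n
    return generated


def output_names(num_samples):
    """File names following the zero-padded 00000000.png pattern."""
    return [f"{i:08d}.png" for i in range(num_samples)]


# Toy "decoder" that sums the latent code (stand-in for a trained network)
samples = sample_prior(decode=sum, latent_dim=10, num_samples=50, batch_size=10, seed=0)
```

With num_samples=50 and batch_size=10, as in the example above, the loop runs five batches of ten codes.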

If you set output_dir to a specific path, the generated images will be saved as .png files named 00000000.png, 00000001.png ... A sampler can be used with any model it is suited to. For instance, a GaussianMixtureSampler instance can be used to generate from any model, but a VAMPSampler is only usable with a VAMP model. Check the samplers table above to see which ones apply to your model. Be careful: some samplers, such as the GaussianMixtureSampler, need to be fitted by calling their fit method before being used. Below is an example for the GaussianMixtureSampler.

>>> from pythae.models import AutoModel
>>> from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig
>>> # Retrieve the trained model
>>> my_trained_vae = AutoModel.load_from_folder(
...	'path/to/your/trained/model'
... )
>>> # Define your sampler
>>> gmm_sampler_config = GaussianMixtureSamplerConfig(
...	n_components=10
... )
>>> gmm_sampler = GaussianMixtureSampler(
...	sampler_config=gmm_sampler_config,
...	model=my_trained_vae
... )
>>> # Fit the sampler (train_data is your training set)
>>> gmm_sampler.fit(train_data)
>>> # Generate samples
>>> gen_data = gmm_sampler.sample(
...	num_samples=50,
...	batch_size=10,
...	output_dir=None,
...	return_gen=True
... )
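Once fitted, a Gaussian mixture sampler generates latent codes in two steps: it picks a component according to the mixture weights, then draws from that component's Gaussian; the codes are then decoded. The pure-Python sketch below shows only this sampling step, assuming diagonal covariances for simplicity and taking the fitted weights, means and log-variances as given (the fitting itself is omitted). The function name is hypothetical and this is not pythae's implementation.

```python
import math
import random


def sample_gmm(weights, means, log_vars, num_samples, seed=None):
    """Draw latent codes from a diagonal-covariance Gaussian mixture (hypothetical sketch).

    weights: mixture weights summing to 1 (one per component)
    means, log_vars: one list of length latent_dim per component
    """
    rng = random.Random(seed)
    components = list(range(len(weights)))
    codes = []
    for _ in range(num_samples):
        # 1) pick a component k with probability weights[k]
        k = rng.choices(components, weights=weights, k=1)[0]
        # 2) sample z ~ N(means[k], diag(exp(log_vars[k])))
        z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
             for m, lv in zip(means[k], log_vars[k])]
        codes.append(z)
    return codes


# Two well-separated 2-D components with small variances
codes = sample_gmm(weights=[0.7, 0.3], means=[[-5.0, -5.0], [5.0, 5.0]],
                   log_vars=[[-4.0, -4.0], [-4.0, -4.0]], num_samples=100, seed=0)
```

Compared with a single N(0, I) prior, a mixture fitted on the training embeddings can follow multi-modal latent distributions, which is why it also works with plain autoencoders.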

Define your own Autoencoder architecture

Pythae lets you define your own neural networks within the VAE models. For instance, say you want to train a Wasserstein AE with a specific encoder and decoder; you can do the following:

>>> import torch
>>> from pythae.models.nn import BaseEncoder, BaseDecoder
>>> from pythae.models.base.base_utils import ModelOutput
>>> # my_nn_layers() is a placeholder: replace it with your own torch.nn layers
>>> class My_Encoder(BaseEncoder):
...	def __init__(self, args=None): # args is a ModelConfig instance
...		BaseEncoder.__init__(self)
...		self.layers = my_nn_layers()
...
...	def forward(self, x: torch.Tensor) -> ModelOutput:
...		out = self.layers(x)
...		output = ModelOutput(
...			embedding=out # Set the output from the encoder in a ModelOutput instance
...		)
...		return output
...
>>> class My_Decoder(BaseDecoder):
...	def __init__(self, args=None):
...		BaseDecoder.__init__(self)
...		self.layers = my_nn_layers()
...
...	def forward(self, z: torch.Tensor) -> ModelOutput:
...		out = self.layers(z)
...		output = ModelOutput(
...			reconstruction=out # Set the output from the decoder in a ModelOutput instance
...		)
...		return output
...
>>> my_encoder = My_Encoder()
>>> my_decoder = My_Decoder()

Then build the model:

>>> from pythae.models import WAE_MMD, WAE_MMD_Config
>>> # Set up the model configuration 
>>> my_wae_config = WAE_MMD_Config(
...	input_dim=(1, 28, 28),
...	latent_dim=10
... )
>>> # Build the model
>>> my_wae_model = WAE_MMD(
...	model_config=my_wae_config,
...	encoder=my_encoder, # pass your encoder as argument when building the model
...	decoder=my_decoder # pass your decoder as argument when building the model
... )

Important note 1: For all AE-based models (AE, WAE, RAE_L2, RAE_GP), both the encoder and decoder must return a ModelOutput instance. For the encoder, the ModelOutput instance must contain the embeddings under the key embedding. For the decoder, the ModelOutput instance must contain the reconstructions under the key reconstruction.

Important note 2: For all VAE-based models (VAE, BetaVAE, IWAE, HVAE, VAMP, RHVAE), both the encoder and decoder must return a ModelOutput instance. For the encoder, the ModelOutput instance must contain the embeddings and the (diagonal) log-covariances, each of shape batch_size x latent_space_dim, under the keys embedding and log_covariance respectively. For the decoder, the ModelOutput instance must contain the reconstructions under the key reconstruction.
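The log_covariance output is what allows VAE-based models to sample latent codes during training. The usual approach is the reparameterization trick, z = embedding + exp(0.5 * log_covariance) * eps with eps ~ N(0, I), which keeps sampling differentiable with respect to the encoder outputs. The pure-Python sketch below applies it to a plain dict standing in for a ModelOutput; the helper name is hypothetical and this is not pythae's code.

```python
import math
import random


def reparameterize(encoder_output, rng=random):
    """Sample z ~ N(embedding, diag(exp(log_covariance))) for each row of a batch (hypothetical helper)."""
    # Both keys are required for VAE-based models
    for key in ("embedding", "log_covariance"):
        if key not in encoder_output:
            raise KeyError(f"encoder output is missing the '{key}' key")
    z_batch = []
    for mu, log_var in zip(encoder_output["embedding"], encoder_output["log_covariance"]):
        # std = exp(0.5 * log_var); eps ~ N(0, 1) drawn independently per dimension
        z_batch.append([m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
                        for m, lv in zip(mu, log_var)])
    return z_batch


# A batch of two 3-D codes with unit variance (log_covariance = 0)
out = {"embedding": [[0.0, 1.0, 2.0], [0.5, 0.5, 0.5]],
       "log_covariance": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]}
z = reparameterize(out, rng=random.Random(0))
```

An encoder that only returns embedding (as AE-based models do) would make this helper raise a KeyError, which mirrors the difference between the two notes above.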

Using benchmark neural nets

You can also find predefined neural network architectures for the most common datasets (e.g. MNIST, CIFAR, CELEBA ...), which can be loaded as follows:

>>> from pythae.models.nn.benchmark.mnist import (
...	Encoder_Conv_AE_MNIST, # For AE-based models (returns embeddings only)
...	Encoder_Conv_VAE_MNIST, # For VAE-based models (returns embeddings and log_covariances)
...	Decoder_Conv_AE_MNIST
... )

Replace mnist with cifar or celeba to access the other neural nets.

Getting your hands on the code

To help you understand how pythae works and how you can train your models with this library, we also provide tutorials:

  • making_your_own_autoencoder.ipynb shows you how to pass your own networks to the models implemented in pythae

  • models_training folder provides notebooks showing how to train each implemented model and how to sample from it using pythae.samplers.

  • scripts folder provides in particular an example of a training script to train the models on benchmark data sets (mnist, cifar10, celeba ...)

Dealing with issues

If you experience any issues while running the code, or want to request new features/models to be implemented, please open an issue on GitHub.

Contributing 🚀

Do you want to contribute to this library by adding a model or a sampler, or simply by fixing a bug? That's awesome! Thank you! Please see CONTRIBUTING.md to follow the main contributing guidelines.

Results

Reconstruction

First let's have a look at the reconstructed samples taken from the evaluation set.

[Figure: MNIST and CELEBA evaluation samples (Eval data) and their reconstructions, one row per model]
Models: AE, VAE, Beta-VAE, VAE Lin NF, VAE IAF, Disentangled Beta-VAE, FactorVAE, BetaTCVAE, IWAE, MSSSIM_VAE, WAE, INFO VAE, VAMP, SVAE, Adversarial_AE, VAE_GAN, VQVAE, HVAE, RAE_L2, RAE_GP, Riemannian Hamiltonian VAE (RHVAE)

Generation

Here, we show the samples generated by each model implemented in the library with different samplers.

[Figure: generated MNIST and CELEBA samples, one row per model + sampler pair]
Model + sampler pairs: AE + GaussianMixtureSampler; VAE + NormalSampler; VAE + GaussianMixtureSampler; VAE + TwoStageVAESampler; VAE + MAFSampler; Beta-VAE + NormalSampler; VAE Lin NF + NormalSampler; VAE IAF + NormalSampler; Disentangled Beta-VAE + NormalSampler; FactorVAE + NormalSampler; BetaTCVAE + NormalSampler; IWAE + NormalSampler; MSSSIM_VAE + NormalSampler; WAE + NormalSampler; INFO VAE + NormalSampler; SVAE + HypersphereUniformSampler; VAMP + VAMPSampler; Adversarial_AE + NormalSampler; VAEGAN + NormalSampler; VQVAE + MAFSampler; HVAE + NormalSampler; RAE_L2 + GaussianMixtureSampler; RAE_GP + GaussianMixtureSampler; Riemannian Hamiltonian VAE (RHVAE) + RHVAESampler


Download files

Source distribution: pythae-0.0.1.tar.gz (108.3 kB)

Built distribution (Python 3): pythae-0.0.1-py3-none-any.whl (197.9 kB)
