Framework to ease training of generative models based on TensorFlow
SimpleGAN
Framework to ease training of generative models
SimpleGAN is a framework based on TensorFlow that makes training generative models easier. It provides high-level APIs with customization options, so users can train a generative model in a few lines of code or reuse modules from the existing architectures to run custom training loops and experiments.
Requirements
Make sure you have the following packages installed
Installation
Latest stable release:
$ pip install simplegan
Latest development release:
$ pip install git+https://github.com/grohith327/simplegan.git
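To confirm which release pip actually installed (for example, after switching between the stable and development installs), you can ask the standard library for the installed distribution version. The following is a minimal sketch that assumes Python 3.8+ for `importlib.metadata`. The helper name `installed_version` is hypothetical and not part of SimpleGAN. On older interpreters, `pip show simplegan` reports the same information.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "simplegan") -> Optional[str]:
    """Hypothetical helper: return the installed version of a distribution, or None."""
    try:
        # Reads the metadata pip recorded at install time.
        return version(dist_name)
    except PackageNotFoundError:
        # The distribution is not installed in the current environment.
        return None


if __name__ == "__main__":
    print(installed_version() or "simplegan is not installed")
```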
Getting Started
DCGAN
from simplegan.gan import DCGAN
## initialize model
gan = DCGAN()
## load train data
train_ds = gan.load_data(use_mnist = True)
## get samples from the data object
samples = gan.get_sample(train_ds, n_samples = 5)
## train the model
gan.fit(train_ds = train_ds)
## get generated samples from model
generated_samples = gan.generate_samples(n_samples = 5)
Custom training loops for GANs
import tensorflow as tf
from simplegan.gan import Pix2Pix
## initialize model
gan = Pix2Pix()
## get generator module of Pix2Pix
generator = gan.generator()  ## A tf.keras model
## get discriminator module of Pix2Pix
discriminator = gan.discriminator()  ## A tf.keras model
## training loop
with tf.GradientTape() as tape:
    """ Custom training loops """
Convolutional Autoencoder
from simplegan.autoencoder import ConvolutionalAutoencoder
## initialize autoencoder
autoenc = ConvolutionalAutoencoder()
## load train and test data
train_ds, test_ds = autoenc.load_data(use_cifar10 = True)
## get sample from data object
train_sample = autoenc.get_sample(data = train_ds, n_samples = 5)
test_sample = autoenc.get_sample(data = test_ds, n_samples = 1)
## train the autoencoder
autoenc.fit(train_ds = train_ds, epochs = 5, optimizer = 'RMSprop', learning_rate = 0.002)
## get generated test samples from model
generated_samples = autoenc.generate_samples(test_ds = test_ds.take(1))
For more detailed examples, check here.
Documentation
Check out the docs page at https://simplegan.readthedocs.io.
Provided models
Model |
---|
Vanilla Autoencoder |
Convolutional Autoencoder |
Variational Autoencoder [Paper] |
Vector-Quantized Variational Autoencoder (VQ-VAE) [Paper] |
Vanilla GAN [Paper] |
DCGAN [Paper] |
WGAN [Paper] |
CGAN [Paper] |
InfoGAN [Paper] |
Pix2Pix [Paper] |
CycleGAN [Paper] |
3DGAN (VoxelGAN) [Paper] |
Self-Attention GAN (SAGAN) [Paper] |
Contributing
We appreciate all contributions. If you plan to fix bugs or add new features or models, please open an issue to discuss it before submitting a pull request.
Citation
@software{simplegan,
author = {{Rohith Gandhi et al.}},
title = {simplegan},
url = {https://simplegan.readthedocs.io},
version = {0.2.8},
}
Source Distribution
File details
Details for the file simplegan-0.2.9.tar.gz.
File metadata
- Download URL: simplegan-0.2.9.tar.gz
- Size: 33.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/44.0.0 requests-toolbelt/0.9.1 tqdm/4.41.1 CPython/3.6.9
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | e0ab7f98125e927960f785a5297cd312963a112d37a010e80b13d72be7847e1c |
MD5 | 2776f58e1a7a725464d055da9bef89f2 |
BLAKE2b-256 | cec7a294ec30023f40f727be1e1c72664f40f16330b676a72e09c4b417d79467 |
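You can use the published digests to check that a downloaded archive is intact. Here is a minimal sketch using the standard library's `hashlib`. The helper names `sha256_of` and `verify_download` are hypothetical, and the expected value is the SHA256 digest listed for simplegan-0.2.9.tar.gz.

```python
import hashlib
from pathlib import Path

# SHA256 digest published for simplegan-0.2.9.tar.gz.
EXPECTED_SHA256 = "e0ab7f98125e927960f785a5297cd312963a112d37a010e80b13d72be7847e1c"


def sha256_of(path, chunk_size: int = 1 << 16) -> str:
    """Hypothetical helper: hash a file in chunks so it never has to fit in memory."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        # iter() with a sentinel keeps reading until read() returns b"".
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected: str = EXPECTED_SHA256) -> bool:
    """Hypothetical helper: compare against a hex digest, ignoring case."""
    return sha256_of(path) == expected.lower()
```

For example, `verify_download("simplegan-0.2.9.tar.gz")` should return `True` for an untampered archive. pip can also enforce this check itself in hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file).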