Keras GAN Library
keragan
Keras implementation of GANs
This library provides some simple infrastructure to define and train Generative Adversarial Networks in Keras. It can also be used from the command line.
Installation
The simplest way to install is:
pip install keragan
Some Images produced by keragan
The models that produced these images were trained on the WikiArt.org dataset, taken from here.
Minimalistic Landscape, 2020 | Spring Dawn, 2020 | Water Still 2, 2020 |
--- | --- | --- |
Training a GAN from the Command Line
To start training, you can use the following command-line:
python -m keragan.trainer c:\dataset --size 1024 --model_path .\models --samples_path .\samples --latent_dim 100 --epochs 1000
You can find out more about other parameters by calling the program with the `--help` option.
Important things to note:
- `--size` must be a power of 2; suitable values are 64, 128, 256, 512 and 1024. Higher resolutions are unlikely to give good results.
- You can use `--lr` to set the learning rate (default 0.001). Smaller learning rates tend to yield better results, but may significantly increase training time.
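One common reason for the power-of-2 requirement: in a typical DCGAN-style generator, a small feature map is doubled in height and width by each upsampling block, so the output side length is `base * 2**k`. Whether keragan's `DCGAN` starts from a 4x4 map is an assumption, not something stated here; `is_power_of_two` and `upsampling_steps` are hypothetical helpers that check a `--size` value and count the doubling steps from that assumed base.

```python
def is_power_of_two(n: int) -> bool:
    # A positive integer is a power of 2 iff exactly one bit is set
    return n > 0 and (n & (n - 1)) == 0


def upsampling_steps(size: int, base: int = 4) -> int:
    """Hypothetical helper: number of 2x upsampling blocks needed to grow a
    base x base feature map to size x size (base=4 is an assumption)."""
    if not is_power_of_two(size) or size < base:
        raise ValueError(f"size must be a power of 2 and at least {base}, got {size}")
    steps = 0
    side = base
    while side < size:
        side *= 2  # each block doubles height and width
        steps += 1
    return steps


for s in (64, 128, 256, 512, 1024):
    print(s, upsampling_steps(s))
```

Under this assumption each step up in `--size` adds one more upsampling block, which is one plausible reason larger resolutions are harder to train well.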
Generating Images
Once you have trained the model, you can use its generator network to produce new random images. To do that from the command line, run:
python -m keragan.generate --file ./models/gen_1100.hdf5 --out ./samples --n 100
Use the `--help` option to find out more about the available options.
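To do the same from Python, a minimal sketch is to sample latent noise, run it through the generator, and rescale the output to 8-bit pixels. This assumes the generator maps `(n, latent_dim)` Gaussian noise to `(n, H, W, C)` images in [-1, 1] (a tanh output, as is common in DCGANs); neither assumption is confirmed for keragan, and `generate_images` is a hypothetical helper, not part of the library.

```python
import numpy as np


def generate_images(predict, n, latent_dim=100, seed=None):
    """Hypothetical helper. `predict` is any callable that behaves like a
    Keras model's predict(): (n, latent_dim) noise -> (n, H, W, C) floats,
    assumed to lie in [-1, 1]."""
    rng = np.random.default_rng(seed)
    noise = rng.normal(0.0, 1.0, size=(n, latent_dim))
    out = predict(noise)
    # Rescale the assumed tanh range [-1, 1] to [0, 255] and convert to uint8
    return ((np.clip(out, -1.0, 1.0) + 1.0) * 127.5).round().astype(np.uint8)
```

With Keras, `predict` could be the `predict` method of a model loaded via `tf.keras.models.load_model("./models/gen_1100.hdf5")`, with `latent_dim` matching the `--latent_dim` value used for training.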
Architecture
The library is structured around a few core classes:
- `GAN` represents a GAN, with `generator` and `discriminator` fields that define the corresponding networks. `GAN` itself is abstract, and any subclass should define the `create_generator()` and `create_discriminator()` functions. This class is also responsible for loading/saving networks to disk, and it can generate sample images using the `sample_images` method.
- `DCGAN` is currently the only subclass, implementing a Deep Convolutional GAN.
- `ImageDataset` defines the process of loading the initial dataset from disk, resizing it to the specified size, filtering out bad images, etc.
- `GANTrainer` is responsible for training a GAN, i.e. running the epoch loop and periodically storing samples and network weights to disk.
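The abstract-base-class pattern described for `GAN` can be sketched in plain Python. This is a hypothetical, Keras-free illustration and not keragan's source: the real `DCGAN` is constructed from `args` and builds Keras models, whereas here a `latent_dim` keyword and placeholder strings stand in for them.

```python
from abc import ABC, abstractmethod


class GAN(ABC):
    """Hypothetical sketch: the base class builds both networks through
    abstract factory methods that each subclass must implement."""

    def __init__(self, latent_dim=100):
        self.latent_dim = latent_dim
        self.epoch = 0
        self.generator = self.create_generator()
        self.discriminator = self.create_discriminator()

    @abstractmethod
    def create_generator(self):
        ...

    @abstractmethod
    def create_discriminator(self):
        ...


class ToyGAN(GAN):
    # A real subclass (like DCGAN) would return Keras models here;
    # strings stand in for networks to keep the sketch dependency-free.
    def create_generator(self):
        return f"generator(latent_dim={self.latent_dim})"

    def create_discriminator(self):
        return "discriminator"
```

Instantiating `GAN` directly raises `TypeError` because its factory methods are abstract, which enforces the "any subclass should define" rule at construction time.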
The actual training code looks like this:
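```python
import matplotlib.pyplot as plt
import keragan

# `args` holds the parsed command-line options (the ones listed by --help),
# e.g. an argparse.Namespace returned by parse_args()
gan = keragan.DCGAN(args)
imsrc = keragan.ImageDataset(args)
imsrc.load()
if args.sample_images:
    imsrc.sample_images()
train = keragan.GANTrainer(image_dataset=imsrc, gan=gan, args=args)

def callbk(tr):
    # Callback passed to train(); it receives the trainer. Every
    # visual_inspection_interval epochs, display samples from the generator.
    if args.visual_inspection_interval and tr.gan.epoch % args.visual_inspection_interval == 0:
        res = tr.gan.sample_images(n=2)
        fig, ax = plt.subplots(1, len(res))
        for i, v in enumerate(res):
            ax[i].imshow(v[0])
        plt.show()

train.train(callbk)
```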
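The control flow of the training code above can be illustrated with a small stand-in for `GANTrainer`. `MiniTrainer` is hypothetical: it assumes the trainer increments `gan.epoch` once per epoch and calls the callback after each epoch, which fits how `callbk` reads `tr.gan.epoch` but is not confirmed from keragan's source. Actual network updates and disk writes are omitted.

```python
from types import SimpleNamespace


class MiniTrainer:
    """Hypothetical sketch of an epoch loop with a per-epoch callback.
    Not keragan's GANTrainer; the real class also trains the networks
    and writes samples and weights to disk."""

    def __init__(self, gan, epochs, save_interval):
        self.gan = gan
        self.epochs = epochs
        self.save_interval = save_interval
        self.saved_at = []  # epochs at which weights would be saved

    def train_epoch(self):
        # Placeholder: a real GAN trainer would alternate discriminator
        # and generator updates over the dataset's batches here.
        pass

    def train(self, callback=None):
        for _ in range(self.epochs):
            self.train_epoch()
            self.gan.epoch += 1
            if self.gan.epoch % self.save_interval == 0:
                self.saved_at.append(self.gan.epoch)
            if callback is not None:
                callback(self)  # same signature as callbk(tr)


inspected = []

def inspect_every_3(tr):
    # Mirrors callbk's interval check with visual_inspection_interval=3
    if tr.gan.epoch % 3 == 0:
        inspected.append(tr.gan.epoch)

trainer = MiniTrainer(SimpleNamespace(epoch=0), epochs=10, save_interval=5)
trainer.train(inspect_every_3)
print(trainer.saved_at, inspected)  # [5, 10] [3, 6, 9]
```

Passing a callback keeps visualization (matplotlib) out of the trainer itself, so the same loop can run headless from the command line.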
Source Distribution
File details
Details for the file keragan-0.0.3.tar.gz.
File metadata
- Download URL: keragan-0.0.3.tar.gz
- Size: 11.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.25.1 setuptools/52.0.0.post20210125 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.7.4
File hashes
Algorithm | Hash digest |
--- | --- |
SHA256 | 589aba934ac14379256d4ded0e029c0f225896d4ce2df2e267cef1022964f7f4 |
MD5 | 4d64b054a9dabc5a5fc66c7aa703cc55 |
BLAKE2b-256 | ea09142a257b38b9e0a8619927ba266bd188a1a0f82f5e936f366d5e3461abe4 |