
Visualization tool for generative models


GEnerative MOdel VIsualization (gemovi)

Installation

Note! Packages in requirements-pinned.txt are fully pinned: you get exactly the versions used in development, so everything is likely to work (provided gemovi is the only package you install). Other versions of each package are probably compatible too. The one exception is the packages in requirements-unreleased.txt, which must be installed from git until their next PyPI release.

Also note! If you want GPU support in PyTorch, make sure the correct CUDA and torch versions are installed (see the PyTorch installation instructions) before attempting to pip install -e gemovi.

In any case, after cloning this package, you can install it as follows:

# *Please* use a virtual environment. It will make package management easier
# and you will not have to worry about messing up your system Python.
# Any python virtual environment manager will do, but I use (mini)conda
conda create -n gemovi-env python=3.10 -y
conda activate gemovi-env

git clone https://gitlab.com/ntjess/gemovi.git
# See discussion above about installing pinned requirements
# pip install -r ./gemovi/requirements-pinned.txt

# See discussion about installing pytorch with GPU support
# before installing gemovi
# Use "train" if you want to train a model as well as visualize
pip install -e "./gemovi[train]"
# Note that one of PyQt5, PyQt6, PySide2, PySide6 is also required. This
# check will fail if you do not have one of these installed.
python -c "import pyqtgraph"

# Finally check if the visualization entrypoint works
python -m gemovi.viz --help
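
The `import pyqtgraph` check above does not say which Qt binding is present. A small standard-library sketch such as the following can list the ones that are discoverable. The helper name `available_qt_bindings` is made up for illustration and is not part of gemovi or pyqtgraph.

```python
from importlib.util import find_spec

# The four bindings named in the installation notes above.
QT_BINDINGS = ("PyQt5", "PyQt6", "PySide2", "PySide6")


def available_qt_bindings(candidates=QT_BINDINGS):
    """Return the candidate packages that can be found on the import path.

    Hypothetical helper: it only checks that each package is discoverable,
    not that it imports or runs correctly.
    """
    return [name for name in candidates if find_spec(name) is not None]


if __name__ == "__main__":
    found = available_qt_bindings()
    if found:
        print("Qt bindings found:", ", ".join(found))
    else:
        print("No Qt binding found; install one of:", ", ".join(QT_BINDINGS))
```

An empty result means one of the listed bindings still needs to be installed before the visualization entrypoint can open a window.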

Usage

gemovi supports training and visualizing various forms of VAEs, along with DCGAN (a deep convolutional generative adversarial network). It uses a config file, very much like PyTorch-VAE's, to specify where the data lives, how large to make the input images, etc. Most parameters are only relevant during training, but some, such as model hyperparameters, are also needed to instantiate the right network during visualization and weight loading.

Training

Assume you have training images stored in path/to/images. This folder should contain images in supported formats, e.g., .jpg, .png, etc. You can train a VAE with the following command:

python -m gemovi.vae.train --config path/to/config.yaml

where path/to/config.yaml is your config file. If the config file lives in gemovi/common/configs/, you can simply pass the filename instead of the full path.

Note that you must specify data_params>data_path to be path/to/images and trainer_params>model_name to be the desired VAE architecture. If your hardware is different from a 1-GPU setup, be sure to change trainer_params>{accelerator,devices} accordingly.
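
As a rough pre-flight check before launching training, the two mandatory settings can be verified on the config once it has been loaded into a nested dictionary. This is only a sketch: `check_config` and `IMAGE_SUFFIXES` are hypothetical names, the extension set is an assumption (the text names only .jpg and .png), and the values in the example are placeholders.

```python
from pathlib import Path

# Assumed extension set: the text names only .jpg and .png ("etc.").
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def check_config(config):
    """Return a list of human-readable problems found in a config mapping.

    Hypothetical helper, not part of gemovi. It checks only the two keys the
    text calls mandatory and that data_path holds at least one image file
    (top level only, subfolders are not searched).
    """
    problems = []
    data_path = (config.get("data_params") or {}).get("data_path")
    model_name = (config.get("trainer_params") or {}).get("model_name")

    if not model_name:
        problems.append("trainer_params>model_name is not set")
    if not data_path:
        problems.append("data_params>data_path is not set")
    elif not Path(data_path).is_dir():
        problems.append(f"data_path is not a directory: {data_path}")
    elif not any(
        p.suffix.lower() in IMAGE_SUFFIXES for p in Path(data_path).iterdir()
    ):
        problems.append(f"no images found in: {data_path}")
    return problems


if __name__ == "__main__":
    # Placeholder values mirroring the paths used in the text.
    example = {
        "data_params": {"data_path": "path/to/images"},
        "trainer_params": {"model_name": "MSSIMVAE"},
    }
    for problem in check_config(example):
        print("config problem:", problem)
```

An empty list means both keys are set and the image folder is not empty; it says nothing about whether the remaining parameters are valid.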

If you wish to resume training from a previous run, specify the saved .ckpt file in trainer_params>ckpt_path.

Note that the procedure is almost exactly the same for a DCGAN. The only difference is the top-level command:

python -m gemovi.dcgan.train --config path/to/config.yaml

Assuming trainer_params>log_dir was unchanged, model checkpoints, sample reconstructions, and TensorBoard logs should begin to appear under ./lightning_logs in your current directory.
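
Once checkpoints exist, the newest one can be located for use in trainer_params>ckpt_path. The following is a minimal sketch, assuming the default ./lightning_logs directory; `latest_checkpoint` is a hypothetical helper, and it searches recursively so that it does not depend on how the log directory is nested.

```python
from pathlib import Path


def latest_checkpoint(log_dir="lightning_logs"):
    """Return the most recently modified .ckpt file under log_dir, or None.

    Hypothetical helper, not part of gemovi. "Latest" here means newest
    modification time, which may differ from the best-scoring checkpoint.
    """
    root = Path(log_dir)
    if not root.is_dir():
        return None
    checkpoints = list(root.rglob("*.ckpt"))
    if not checkpoints:
        return None
    return max(checkpoints, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    ckpt = latest_checkpoint()
    print(ckpt if ckpt else "no checkpoint found")
```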

Visualization

[Note: If the fire package is installed, the --help flag will print much more detailed argument information.]

Assume you previously trained an MSSIMVAE with python -m gemovi.vae.train --config path/to/config.yaml. You can visualize the latent space with the following command:

python -m gemovi.viz --model_class MSSIMVAE --config_file path/to/config.yaml --weights_file path/to/checkpoint.ckpt
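
If you launch the visualizer from a script rather than a shell, the same command can be assembled as an argument list. This sketch uses only the three flags shown above; `viz_command` is a hypothetical helper, and the paths are the placeholders from the text.

```python
import shlex
import sys


def viz_command(model_class, config_file, weights_file):
    """Build the argv list for the visualization command shown above.

    Hypothetical helper, not part of gemovi. sys.executable is used so the
    command runs in the same Python environment as the calling script.
    """
    return [
        sys.executable, "-m", "gemovi.viz",
        "--model_class", model_class,
        "--config_file", str(config_file),
        "--weights_file", str(weights_file),
    ]


if __name__ == "__main__":
    cmd = viz_command("MSSIMVAE", "path/to/config.yaml", "path/to/checkpoint.ckpt")
    # Printed rather than executed; pass `cmd` to subprocess.run to launch it.
    print(shlex.join(cmd))
```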
