
BIVA (PyTorch)

Official PyTorch implementation of BIVA (BIVA: A Very Deep Hierarchy of Latent Variables for Generative Modeling) for binarized MNIST and CIFAR-10. The original TensorFlow implementation can be found here.

run the experiments

conda create --name biva python=3.7
conda activate biva
pip install -r requirements.txt
CUDA_VISIBLE_DEVICES=0 python run_deepvae.py --dataset binmnist --q_dropout 0.5 --p_dropout 0.5 --device cuda
CUDA_VISIBLE_DEVICES=0 python run_deepvae.py --dataset cifar10 --q_dropout 0.2 --p_dropout 0 --device cuda

Citation

@article{maale2019biva,
    title={BIVA: A Very Deep Hierarchy of Latent Variables for Generative Modeling},
    author={Lars Maaløe and Marco Fraccaro and Valentin Liévin and Ole Winther},
    year={2019},
    eprint={1902.02102},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}

Pip package

install requirements

  • pytorch 1.3.0
  • torchvision
  • matplotlib
  • tensorboard
  • booster-pytorch==0.0.2
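
The requirements above can be installed in one step; a minimal sketch, assuming the PyPI package names torch and torchvision for the PyTorch dependencies:

pip install torch==1.3.0 torchvision matplotlib tensorboard booster-pytorch==0.0.2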

install package

pip install git+https://github.com/vlievin/biva-pytorch.git
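
To check that the install worked, the imports used in the example below should resolve; a quick smoke test (not from the original README):

python -c "from biva import VAE, LVAE, BIVA, DenseNormal, ConvNormal"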

build deep VAEs

import torch
from torch.distributions import Bernoulli

from biva import DenseNormal, ConvNormal
from biva import VAE, LVAE, BIVA

# build a 2-layer VAE for binary images

# define the stochastic layers
z = [
    {'N': 8, 'kernel': 5, 'block': ConvNormal},  # z1
    {'N': 16, 'block': DenseNormal}  # z2
]

# define the intermediate layers
# each stage defines the configuration of the blocks for q(z_l | z_{l-1}) and p(z_{l-1} | z_l)
# each stage is defined by a sequence of 3 ResNet blocks
# each block is defined by a tuple [filters, kernel, stride]
stages = [
    [[64, 3, 1], [64, 3, 1], [64, 3, 2]],
    [[64, 3, 1], [64, 3, 1], [64, 3, 2]]
]

# build the model
model = VAE(tensor_shp=(-1, 1, 28, 28), stages=stages, latents=z, dropout=0.5)

# forward pass and data-dependent initialization
x = torch.empty((8, 1, 28, 28)).uniform_().bernoulli()
data = model(x)  # data = {'x_': logits of p(x|z) with z ~ q(z|x), 'kl': [kl_z1, kl_z2]}

# sample from prior
data = model.sample_from_prior(N=16)  # data = {'x_': logits of p(x|z) with z ~ p(z)}
samples = Bernoulli(logits=data['x_']).sample()
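
A minimal training-step sketch, not part of the original example: it assumes data['x_'] holds Bernoulli logits (consistent with the sampling call above) and that each entry of data['kl'] is a per-sample KL tensor; the optimizer and learning rate are assumptions.

import torch
from torch.distributions import Bernoulli

optimizer = torch.optim.Adamax(model.parameters(), lr=2e-3)  # optimizer choice is an assumption

# ELBO = E_q[log p(x|z)] - sum_l KL(q(z_l|...) || p(z_l)), one KL term per latent layer
data = model(x)
log_px_z = Bernoulli(logits=data['x_']).log_prob(x).sum(dim=(1, 2, 3))
kl = sum(kl_l.reshape(kl_l.size(0), -1).sum(1) for kl_l in data['kl'])
loss = -(log_px_z - kl).mean()  # negative ELBO, averaged over the batch

optimizer.zero_grad()
loss.backward()
optimizer.step()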

