
package for implementing guided variational autoencoders


proteovae

This library provides a convenient set of modules for designing and implementing several variational autoencoder frameworks. Support is currently provided for the vanilla VAE, the beta-VAE, and the guided VAE (GVAE) introduced here.

proteovae also provides a few different model trainers to facilitate the training process, although you can just as well use standard PyTorch or Lightning. This package was developed as a tool for exploring genomics data (hence proteo[mics]-vae); for a much more comprehensive suite of VAE implementations, I would point you in the direction of pythae.
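For orientation, the vanilla and beta-VAE variants optimize the usual evidence lower bound with a weighted KL term, while the guided VAE additionally attaches a small guide network that maps a designated slice of the latent code to label predictions (as the Guide object in the examples below suggests). The snippet below is a minimal, illustrative sketch of the beta-weighted objective written as a standalone function; it is not part of the proteovae API.

>>> import torch
>>> import torch.nn.functional as F
...
>>> def beta_elbo_loss(x, x_hat, mu, log_var, beta=1.0):
...     # reconstruction term: how well the decoder reproduces the input
...     recon = F.mse_loss(x_hat, x, reduction="sum")
...     # KL divergence between q(z|x) = N(mu, diag(exp(log_var))) and the prior N(0, I)
...     kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
...     return recon + beta * kld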

News 📢

Version 0.0.1 now on PyPI! ❤️🇮🇹🧑‍🔬

Quick Access

Installation

To install the latest stable release of this library, run the following pip command:

$ pip install proteovae

Defining Custom Architectures

In addition to the models provided in the proteovae.models.base module, you can also write your own encoder and decoder architectures for the VAE you're fitting.

>>> from proteovae.models.base import Encoder, Guide, Decoder
>>> from proteovae.models import GuidedConfig, GuidedVAE
>>> import torch 
>>> from torch import nn 
...
>>> input_dim = 64
>>> latent_dim = 10
>>> guided_dim = 1
>>> n_classes = 2 # dummy 
...
>>> config = GuidedConfig(
...     input_dim = input_dim,
...     latent_dim = latent_dim, 
...     guided_dim = guided_dim
... )
...
>>> # using proteovae.models objects
>>> enc = Encoder(
...         input_dim=input_dim, 
...         latent_dim=latent_dim, 
...         hidden_dims = [32,16,]
... )
>>> dec = Decoder(
...         output_dim = input_dim, 
...         latent_dim = latent_dim, 
...         hidden_dims = [16,32,]
... )
...
>>> gvae1 = GuidedVAE(
...         model_config = config,
...         encoder = enc,
...         decoder = dec, 
...         guide = Guide(dim_in = guided_dim, dim_out = n_classes)
... )
...
>>> # or with generic torch objects
>>> class CustomDecoder(nn.Module):
...     def __init__(self, **kwargs):
...         super().__init__(**kwargs)
...         self.fwd_block = nn.Sequential(
...             nn.Linear(latent_dim, 2*latent_dim),
...             nn.Tanh(),
...             nn.Linear(2*latent_dim, input_dim),
...         )
...     def forward(self, x):
...         return self.fwd_block(x)
...
>>> custom_dec = CustomDecoder()
>>> gvae2 = GuidedVAE(
...         model_config = config,
...         encoder = enc,
...         decoder = custom_dec, 
...         guide = Guide(dim_in = guided_dim, dim_out = n_classes)
... )

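A custom encoder can be supplied in the same way. The sketch below is illustrative only: it assumes the encoder should return a mean and log-variance pair for the latent posterior, mirroring the bundled Encoder; check the proteovae.models.base documentation for the exact return convention expected by GuidedVAE.

>>> # illustrative sketch; the (mu, log_var) return convention is an assumption
>>> class CustomEncoder(nn.Module):
...     def __init__(self, **kwargs):
...         super().__init__(**kwargs)
...         self.backbone = nn.Sequential(
...             nn.Linear(input_dim, 2*latent_dim),
...             nn.Tanh(),
...         )
...         self.mu = nn.Linear(2*latent_dim, latent_dim)
...         self.log_var = nn.Linear(2*latent_dim, latent_dim)
...     def forward(self, x):
...         h = self.backbone(x)
...         return self.mu(h), self.log_var(h)
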
Model Training

Two trainers, BaseTrainer and ScheduledTrainer, are provided in proteovae.trainers to bundle up a lot of the annoying aspects of writing training loops in PyTorch. As shown below, their use is fairly straightforward.

>>> from proteovae.trainers import ScheduledTrainer 
>>> from torch import optim 
...
>>> # define any proteovae.models object and torch data loaders
>>> model = # ... 
>>> train_loader = # ... 
>>> val_loader = # ... 
...
>>> # define optimizer and lr_scheduler
>>> n_epochs = 5 
>>> optimizer = optim.Adam(model.parameters(), 
...                        lr=1e-03)
>>> scheduler = optim.lr_scheduler.LinearLR(optimizer, 
...                                         start_factor=1.0, 
...                                         end_factor=0.33, 
...                                         total_iters=n_epochs*len(train_loader))
>>> # trainer init
>>> trainer = ScheduledTrainer(model, optimizer, scheduler)
...
>>> # train
>>> trainer.train(train_loader, n_epochs, val_data=val_loader)
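
If you prefer full control over the loop, the trainers can be bypassed entirely. The sketch below is a rough, illustrative equivalent of what ScheduledTrainer automates; compute_loss is a hypothetical stand-in for however your model exposes its training objective and is not a proteovae function.

>>> for epoch in range(n_epochs):
...     model.train()
...     for batch in train_loader:
...         optimizer.zero_grad()
...         loss = compute_loss(model, batch)  # hypothetical helper, not part of proteovae
...         loss.backward()
...         optimizer.step()
...         scheduler.step()  # LinearLR is stepped per batch (total_iters = n_epochs * len(train_loader))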

Tutorials
