proteovae
A package for implementing guided variational autoencoders.
This library provides a convenient set of modules for designing and implementing several different variational autoencoder frameworks. So far, support is provided for the vanilla VAE, the beta-VAE, and the guided VAE (GVAE) presented here.
proteovae also provides a few different model trainers to facilitate the training process, although you can also use standard PyTorch or Lightning as you please. This package was developed as a tool to explore genomics data (hence proteo[mics]-vae); for a much more comprehensive suite of VAE implementations, I would point you in the direction of pythae.
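Conceptually, a guided VAE augments the usual ELBO with a supervised term that pushes a small, designated block of latent dimensions to be predictive of a label of interest via an auxiliary classifier (the guide). The snippet below is an illustrative plain-PyTorch sketch of that kind of objective, not proteovae's actual loss; the function name and the beta/gamma weights are placeholders.

>>> import torch
>>> import torch.nn.functional as F
...
>>> def guided_elbo_sketch(x, x_hat, mu, log_var, guide_logits, labels,
...                        beta=1.0, gamma=1.0):
...     #reconstruction term (Gaussian likelihood up to a constant -> summed squared error)
...     recon = F.mse_loss(x_hat, x, reduction="sum")
...     #KL divergence between q(z|x) = N(mu, diag(exp(log_var))) and the N(0, I) prior
...     kl = -0.5 * torch.sum(1.0 + log_var - mu.pow(2) - log_var.exp())
...     #supervised guide term: logits come from a classifier fed only the guided latent dims
...     guide = F.cross_entropy(guide_logits, labels, reduction="sum")
...     return recon + beta * kl + gamma * guide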
News 📢
Version 0.0.1 now on PyPI! ❤️🇮🇹🧑🔬
Quick Access
Installation
To install the latest stable release of this library, run the following pip command:
$ pip install proteovae
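Once installed, the package should import cleanly:

>>> import proteovae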
Defining Custom Architectures
In addition to the models provided in the proteovae.models.base module, you can also write your own encoder and decoder architectures for the VAE you're fitting.
>>> from proteovae.models.base import Encoder, Guide, Decoder
>>> from proteovae.models import GuidedConfig, GuidedVAE
>>> import torch
>>> from torch import nn
...
>>> input_dim = 64
>>> latent_dim = 10
>>> guided_dim = 1
>>> n_classes = 2 # dummy
...
>>> config = GuidedConfig(
... input_dim = input_dim,
... latent_dim = latent_dim,
... guided_dim = guided_dim
... )
...
>>> #using proteovae.models objects
>>> enc = Encoder(
... input_dim=input_dim,
... latent_dim=latent_dim,
... hidden_dims = [32,16,]
... )
>>> dec = Decoder(
... output_dim = input_dim,
... latent_dim = latent_dim,
... hidden_dims = [16,32,]
... )
...
>>> gvae1 = GuidedVAE(
... model_config = config,
... encoder = enc,
... decoder = dec,
... guide = Guide(dim_in = guided_dim, dim_out = n_classes)
... )
...
>>> #or with generic torch objects
>>> class CustomDecoder(nn.Module):
... def __init__(self, **kwargs):
... super().__init__(**kwargs)
... self.fwd_block = nn.Sequential(
... nn.Linear(latent_dim, 2*latent_dim),
... nn.Tanh(),
... nn.Linear(2*latent_dim, input_dim),
... )
... def forward(self, x):
... return self.fwd_block(x)
...
>>> custom_dec = CustomDecoder()
>>> gvae2 = GuidedVAE(
... model_config = config,
... encoder = enc,
... decoder = custom_dec,
... guide = Guide(dim_in = guided_dim, dim_out = n_classes)
... )
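Since CustomDecoder is an ordinary torch.nn.Module, a quick shape check with dummy data (plain PyTorch, no proteovae API involved) confirms that it maps the 10-dimensional latent space back to the 64-dimensional input space:

>>> z = torch.randn(4, latent_dim)
>>> custom_dec(z).shape
torch.Size([4, 64])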
Model Training
Two different trainers (BaseTrainer and ScheduledTrainer) are provided in proteovae.trainers to bundle up a lot of the annoying aspects of defining training loops in PyTorch. As shown below, their usage is fairly straightforward:
>>> from proteovae.trainers import ScheduledTrainer
>>> from torch import optim
...
>>> #define any proteovae.models object and torch data loaders
>>> model = # ...
>>> train_loader = # ...
>>> val_loader = # ...
...
>>> #define optimizer and lr_scheduler
>>> n_epochs = 5
>>> optimizer = optim.Adam(model.parameters(),
... lr=1e-03)
>>> scheduler = optim.lr_scheduler.LinearLR(optimizer,
... start_factor=1.0,
... end_factor=0.33,
... total_iters=n_epochs*len(train_loader))
...
>>> #trainer init
>>> trainer = ScheduledTrainer(model, optimizer, scheduler)
...
>>> #train
>>> trainer.train(train_loader, n_epochs, val_data = val_loader)
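When no learning-rate schedule is needed, BaseTrainer can be used in the same way. The call below is a sketch that assumes BaseTrainer mirrors ScheduledTrainer's interface minus the scheduler argument; check proteovae.trainers for the exact signature.

>>> from proteovae.trainers import BaseTrainer
...
>>> #assumed signature: model + optimizer, mirroring ScheduledTrainer without the scheduler
>>> base_trainer = BaseTrainer(model, optimizer)
>>> base_trainer.train(train_loader, n_epochs, val_data = val_loader)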
Tutorials
- nonlinear_pca.ipynb motivates a generic use case for the GVAE framework