package for implementing guided variational autoencoders
proteovae
This library provides a set of modules for designing and implementing several variational autoencoder (VAE) frameworks. So far it supports the vanilla VAE, the beta-VAE, and the guided VAE (GVAE) presented here.
proteovae also provides a few model trainers to facilitate training, although you can use standard PyTorch or Lightning if you prefer. This package was developed as a tool to explore genomics data (hence proteo[mics]-vae). For a much more comprehensive suite of VAE implementations, see pythae.
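For readers new to these frameworks, the relationship between the vanilla VAE and the beta-VAE objectives can be sketched in plain PyTorch. This is a generic textbook formulation, not proteovae's internal loss: the function name `beta_vae_loss`, the choice of a summed squared-error reconstruction term, and the tensor shapes are all assumptions made for illustration, and the additional guide term used by the GVAE is not shown.

```python
import torch
import torch.nn.functional as F


def beta_vae_loss(x, x_hat, mu, log_var, beta=1.0):
    """Negative ELBO with a beta-weighted KL term (hypothetical helper).

    beta = 1 recovers the vanilla VAE objective; beta > 1 gives a beta-VAE.
    """
    # Squared reconstruction error, summed over features, averaged over the batch.
    recon = F.mse_loss(x_hat, x, reduction="none").sum(dim=1).mean()
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kl = (-0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(dim=1)).mean()
    return recon + beta * kl, recon, kl


torch.manual_seed(0)
x = torch.randn(8, 64)         # batch of 8 samples with 64 features
x_hat = torch.zeros_like(x)    # stand-in for a decoder output
mu = torch.zeros(8, 10)        # posterior means, 10 latent dimensions
log_var = torch.zeros(8, 10)   # posterior log-variances

total, recon, kl = beta_vae_loss(x, x_hat, mu, log_var, beta=4.0)
```

When the posterior equals the standard normal prior (zero mean, zero log-variance, as above), the KL term vanishes and the total loss reduces to the reconstruction error regardless of `beta`.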
News 📢
Version 0.0.1 now on PyPI! ❤️🇮🇹🧑🔬
Quick Access
Installation
To install the latest stable release of this library, run the following using pip:
$ pip install proteovae
Defining Custom Architectures
In addition to the models provided in the proteovae.models.base module, you can also write your own encoder and decoder architectures for the VAE you're fitting.
>>> from proteovae.models.base import Encoder, Guide, Decoder
>>> from proteovae.models import GuidedConfig, GuidedVAE
>>> import torch
>>> from torch import nn
...
>>> input_dim = 64
>>> latent_dim = 10
>>> guided_dim = 1
>>> n_classes = 2 # dummy
...
>>> config = GuidedConfig(
... input_dim = input_dim,
... latent_dim = latent_dim,
... guided_dim = guided_dim
... )
...
>>> #using proteovae.models objects
>>> enc = Encoder(
... input_dim=input_dim,
... latent_dim=latent_dim,
... hidden_dims = [32,16,]
... )
>>> dec = Decoder(
... output_dim = input_dim,
... latent_dim = latent_dim,
... hidden_dims = [16,32,]
... )
...
>>> gvae1 = GuidedVAE(
...     model_config = config,
...     encoder = enc,
...     decoder = dec,
...     guide = Guide(dim_in = guided_dim, dim_out = n_classes)
... )
...
>>> #or with generic torch objects
>>> class CustomDecoder(nn.Module):
...     def __init__(self, **kwargs):
...         super().__init__(**kwargs)
...         self.fwd_block = nn.Sequential(
...             nn.Linear(latent_dim, 2*latent_dim),
...             nn.Tanh(),
...             nn.Linear(2*latent_dim, input_dim),
...         )
...     def forward(self, x):
...         return self.fwd_block(x)
...
>>> custom_dec = CustomDecoder()
>>> gvae2 = GuidedVAE(
...     model_config = config,
...     encoder = enc,
...     decoder = custom_dec,
...     guide = Guide(dim_in = guided_dim, dim_out = n_classes)
... )
Model Training
proteovae.trainers provides two trainers, BaseTrainer and ScheduledTrainer, that bundle up much of the boilerplate of writing training loops in PyTorch. As shown below, using them is fairly straightforward:
>>> from proteovae.trainers import ScheduledTrainer
>>> from torch import optim
...
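>>> #define any proteovae.models object and torch data loaders
>>> model = ...  # any proteovae.models object
>>> train_loader = ...  # torch DataLoader for training data
>>> val_loader = ...  # torch DataLoader for validation data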
...
>>> #define optimizer and lr_scheduler
>>> n_epochs = 5
>>> optimizer = optim.Adam(model.parameters(),
... lr=1e-03)
>>> scheduler = optim.lr_scheduler.LinearLR(optimizer,
... start_factor=1.0,
... end_factor=0.33,
... total_iters=n_epochs*len(train_loader))
>>> #trainer init
>>> trainer = ScheduledTrainer(model, optimizer, scheduler)
...
>>> #train
>>> trainer.train(train_loader, n_epochs, val_data = val_loader)
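The data loaders and the learning-rate schedule in the snippet above can be tried out without a real dataset. The following is a minimal sketch under stated assumptions: the synthetic tensors, the `(features, labels)` batch layout, and the `nn.Linear` stand-in model are illustrative only, and whether proteovae's trainers expect batches in this form or step the scheduler once per batch is not confirmed by this README (the `total_iters=n_epochs*len(train_loader)` setting merely suggests per-batch stepping).

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
input_dim, n_train, n_val, batch_size = 64, 256, 64, 32

# Synthetic features and binary labels standing in for a real dataset (assumption).
train_ds = TensorDataset(torch.randn(n_train, input_dim),
                         torch.randint(0, 2, (n_train,)))
val_ds = TensorDataset(torch.randn(n_val, input_dim),
                       torch.randint(0, 2, (n_val,)))
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size)

# Stand-in module so the optimizer has parameters; swap in a proteovae model.
model = nn.Linear(input_dim, input_dim)

n_epochs = 5
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=1.0,
    end_factor=0.33,
    total_iters=n_epochs * len(train_loader),
)

# Step the scheduler once per batch to see how the learning rate decays.
for epoch in range(n_epochs):
    for _ in train_loader:
        optimizer.step()   # no gradients were computed, so parameters stay unchanged
        scheduler.step()

final_lr = scheduler.get_last_lr()[0]
```

With 256 training samples and a batch size of 32 there are 8 batches per epoch, so the schedule spans 40 steps and the learning rate ends at 0.33 times its initial value.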
Tutorials
- nonlinear_pca.ipynb motivates a generic use case for the GVAE framework
File details
Details for the file proteovae-0.0.1.tar.gz.
File metadata
- Download URL: proteovae-0.0.1.tar.gz
- Upload date:
- Size: 14.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e4759f68393ab4e20fabeb436d3db84a3aacb2a380fe92c206ab44066cb790f0 |
| MD5 | fa04822f73014260f9240cb0494df470 |
| BLAKE2b-256 | 083d1b3ca15e061c6eb6e50c39bc379ff3dc72793dcca4e628180507965c5593 |
File details
Details for the file proteovae-0.0.1-py3-none-any.whl.
File metadata
- Download URL: proteovae-0.0.1-py3-none-any.whl
- Upload date:
- Size: 14.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b200fbfbb9e2a3f439bce267a6ff21fd115cd45fc409862994553c0e60393558 |
| MD5 | 1bdf9278517e823d172ed5ad64a7411a |
| BLAKE2b-256 | 823639170ebc9bb2df8a6bab2782916973fe7b55d83104dab8566ff511813b36 |