
Interpolate between discrete sequences.


Transformer-VAE (WIP)

Diagram of a State Autoencoder

Transformer-VAEs learn smooth latent spaces of discrete sequences without any explicit rules in their decoders.

This can be used for program synthesis, drug discovery, music generation and much more!

To see how it works, check out this blog post.

This repo is in active development, but I should be coming out with a full release soon.


Install using pip:

pip install transformer_vae


You can execute the module directly to train it on your own data.

python -m transformer_vae \
    --project_name="T5-VAE" \
    --output_dir=poet \
    --do_train \
    --huggingface_dataset=poems

Alternatively, import Transformer-VAE as a package and use it much like a Hugging Face model.

from transformer_vae import T5_VAE_Model

model = T5_VAE_Model.from_pretrained('t5-vae-poet')


Set up Weights & Biases for logging (see the W&B client).

Get a dataset to model. It must be represented as text; this is what we will be interpolating over.

This can be a text file with each line representing a sample.

python -m transformer_vae \
    --project_name="T5-VAE" \
    --output_dir=poet \
    --do_train \
    --train_file=poems.txt
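Samples such as poems often contain line breaks, which would be read as separate samples in this one-per-line format. A minimal sketch of preparing such a file, assuming internal newlines should simply become spaces; write_line_per_sample and read_line_per_sample are hypothetical helpers, not part of transformer_vae.

```python
from pathlib import Path


def write_line_per_sample(samples, path):
    """Write each sample on its own line (hypothetical helper).

    Assumption: every line is read as a separate sample, so internal
    line breaks are replaced with spaces.
    """
    cleaned = [" ".join(s.splitlines()).strip() for s in samples]
    # Drop samples that end up empty; they would become blank lines.
    cleaned = [s for s in cleaned if s]
    Path(path).write_text("\n".join(cleaned) + "\n", encoding="utf-8")


def read_line_per_sample(path):
    """Read the samples back, one per non-blank line."""
    text = Path(path).read_text(encoding="utf-8")
    return [line for line in text.splitlines() if line.strip()]
```

For example, write_line_per_sample(poems, "poems.txt") produces a file that can be passed with --train_file=poems.txt. If the line breaks inside a sample matter, the <|endoftext|> format is a better fit.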

Alternatively, separate samples with a line containing only <|endoftext|>:

python -m transformer_vae \
    --project_name="T5-VAE" \
    --output_dir=poet \
    --do_train \
    --train_file=poems.txt
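A minimal sketch of writing and reading this format, which keeps multi-line samples intact. It assumes no separator follows the last sample and that blank samples should be dropped; the helper names are hypothetical, and how transformer_vae itself splits the file is not documented here.

```python
from pathlib import Path

SEPARATOR = "<|endoftext|>"


def write_endoftext_samples(samples, path):
    """Join samples with a line containing only <|endoftext|> (hypothetical helper)."""
    body = f"\n{SEPARATOR}\n".join(s.strip("\n") for s in samples)
    Path(path).write_text(body + "\n", encoding="utf-8")


def read_endoftext_samples(path):
    """Split a file into samples on lines that contain only the separator."""
    samples, current = [], []
    for line in Path(path).read_text(encoding="utf-8").splitlines():
        if line.strip() == SEPARATOR:
            samples.append("\n".join(current))
            current = []
        else:
            current.append(line)
    if current:
        samples.append("\n".join(current))
    # Drop empty samples, e.g. those produced by a trailing separator.
    return [s for s in samples if s.strip()]
```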

Alternatively, provide a Hugging Face dataset, using --content_key to name the field that holds the text.

python -m transformer_vae \
    --project_name="T5-VAE" \
    --output_dir=poet \
    --do_train \
    --dataset=poems \
    --content_key text

Experiment with different parameters.

Once finished, upload the model to the Hugging Face Model Hub.


Explore the resulting latent space using Colab_T5_VAE.ipynb or by visiting this Colab page.
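Interpolating between two sequences amounts to encoding both, walking a path between their latent codes, and decoding each point back into a sequence. The sketch below covers only the latent-side arithmetic, using plain linear interpolation and assuming latent codes are fixed-length lists of floats; lerp and interpolation_path are hypothetical helpers, and the encode/decode calls are left out because this page does not document them.

```python
def lerp(z_a, z_b, t):
    """Linearly interpolate two latent vectors: t=0 gives z_a, t=1 gives z_b."""
    if len(z_a) != len(z_b):
        raise ValueError("latent vectors must have the same dimension")
    return [(1.0 - t) * a + t * b for a, b in zip(z_a, z_b)]


def interpolation_path(z_a, z_b, steps):
    """Return `steps` evenly spaced latent points from z_a to z_b, endpoints included."""
    if steps < 2:
        raise ValueError("need at least two steps to include both endpoints")
    return [lerp(z_a, z_b, i / (steps - 1)) for i in range(steps)]
```

Each point on the path would then be decoded; in a smooth latent space, neighbouring points should decode to similar sequences.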


Install with tests:

pip install -e ".[test]"

Possible contributions to make:

  1. Could the docs be more clear? Would it be worth having a docs site/blog?
  2. Use a Funnel transformer encoder, is it more efficient?
  3. Allow defining an alternative token set.
  4. Store the latent codes from the previous step to use in MMD loss so smaller batch sizes are possible.
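In MMD-based VAEs (such as InfoVAE), the MMD (Maximum Mean Discrepancy) loss compares a batch of latent codes with samples from the prior, and with small batches that estimate is noisy. A minimal pure-Python sketch of the idea in item 4, assuming a Gaussian RBF kernel and the biased estimator; LatentMemory and the other names are hypothetical, not part of transformer_vae. In a real PyTorch training loop the stored codes would also need to be detached from the graph.

```python
import math
from collections import deque


def rbf_kernel(x, y, bandwidth=1.0):
    """Gaussian (RBF) kernel between two equal-length vectors."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))


def mmd(xs, ys, bandwidth=1.0):
    """Biased (V-statistic) estimate of the squared MMD between two samples."""
    def mean_k(us, vs):
        total = sum(rbf_kernel(u, v, bandwidth) for u in us for v in vs)
        return total / (len(us) * len(vs))
    return mean_k(xs, xs) + mean_k(ys, ys) - 2.0 * mean_k(xs, ys)


class LatentMemory:
    """Hypothetical buffer of latent codes from previous training steps.

    Pooling the current batch with recent codes gives the MMD estimate
    more samples than one small batch provides.
    """

    def __init__(self, max_codes=256):
        self.codes = deque(maxlen=max_codes)

    def pooled(self, batch):
        # Combine stored codes with the current batch, then remember the batch.
        pool = list(self.codes) + list(batch)
        self.codes.extend(batch)
        return pool
```

A training step would then compute mmd(memory.pooled(batch_codes), prior_samples), with the prior samples drawn to match the pooled size.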

Feel free to ask what would be useful!


Download files


Source Distribution

transformer_vae-0.0.2.tar.gz (24.3 kB)

Uploaded Source

Built Distribution

transformer_vae-0.0.2-py3-none-any.whl (27.4 kB)

Uploaded Python 3
