LangVAE: Large Language VAEs made simple

LangVAE is a Python library for training and running Variational Autoencoder (VAE) language models. It provides an easy-to-use interface for training VAEs on text data and lets users customize the model architecture, loss function, and training parameters.

Installation

To install LangVAE, simply run:

pip install langvae

This will install all necessary dependencies and set up the package for use in your Python projects.
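
To confirm the installation from Python, the standard library's importlib.metadata can report the installed version. This is a minimal sketch; the helper name installed_version is hypothetical and not part of LangVAE:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version("langvae")
    print(f"langvae {found}" if found else "langvae is not installed")
```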

Usage

Here's a basic example of how to train a VAE on text data using LangVAE:

from pythae.models.vae import VAEConfig
from langvae import LangVAE
from langvae.encoders import SentenceEncoder
from langvae.decoders import SentenceDecoder
from langvae.data_conversion.tokenization import TokenizedDataSet
from langvae.pipelines import LanguageTrainingPipeline
from langvae.trainers import CyclicalScheduleKLThresholdTrainerConfig
from saf_datasets import EntailmentBankDataSet

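DEVICE = "cuda"    # assumes a CUDA-capable GPU
LATENT_SIZE = 32   # dimensionality of the latent space
MAX_SENT_LEN = 32  # maximum sentence length passed to the decoder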

# Load pre-trained sentence encoder and decoder models.
decoder = SentenceDecoder("gpt2", LATENT_SIZE, MAX_SENT_LEN, device=DEVICE)
encoder = SentenceEncoder("bert-base-cased", LATENT_SIZE, decoder.tokenizer, device=DEVICE)

# Select explanatory sentences from the EntailmentBank dataset.
dataset = [
    sent for sent in EntailmentBankDataSet()
    if (sent.annotations["type"] == "answer" or
        sent.annotations["type"].startswith("context"))
]

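# Set training and evaluation datasets with auto tokenization.
# The last 10% of the sentences are held out for evaluation.
# max(1, ...) keeps eval_size non-zero: with eval_size == 0, dataset[:-eval_size] would be empty.
eval_size = max(1, int(0.1 * len(dataset)))
train_dataset = TokenizedDataSet(dataset[:-eval_size], decoder.tokenizer, decoder.max_len)
eval_dataset = TokenizedDataSet(dataset[-eval_size:], decoder.tokenizer, decoder.max_len)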

# Define VAE model configuration
model_config = VAEConfig(
    input_dim=(train_dataset[0]["data"].shape[-2], train_dataset[0]["data"].shape[-1]),
    latent_dim=LATENT_SIZE
)

# Initialize LangVAE model
model = LangVAE(model_config, encoder, decoder)

# Train VAE on explanatory sentences
training_config = CyclicalScheduleKLThresholdTrainerConfig(
    output_dir='expl_vae',
    num_epochs=5,
    learning_rate=1e-4,
    per_device_train_batch_size=50,
    per_device_eval_batch_size=50,
    steps_saving=1,
    optimizer_cls="AdamW",
    scheduler_cls="ReduceLROnPlateau",
    scheduler_params={"patience": 5, "factor": 0.5},
    max_beta=1.0,
    n_cycles=40,
    target_kl=2.0
)

pipeline = LanguageTrainingPipeline(
    training_config=training_config,
    model=model
)

pipeline(
    train_data=train_dataset,
    eval_data=eval_dataset
)

This example loads a pre-trained encoder (bert-base-cased) and decoder (gpt2), defines the VAE model configuration, initializes the LangVAE model, and trains it on explanatory sentences from EntailmentBank using LanguageTrainingPipeline.
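
The trainer configuration above sets max_beta, n_cycles, and target_kl, but the example does not spell out how they interact. The sketch below shows one common reading, cyclical KL annealing combined with a lower threshold on the KL term (often called "free bits"). It is not taken from LangVAE's source: the function names, the proportion parameter, and the linear ramp shape are assumptions made for illustration.

```python
def cyclical_beta(step: int, total_steps: int, n_cycles: int = 40,
                  max_beta: float = 1.0, proportion: float = 0.5) -> float:
    """KL weight at a training step under a cyclical linear schedule.

    Each cycle ramps beta linearly from 0 to max_beta over the first
    `proportion` of the cycle, then holds it at max_beta.
    """
    cycle_len = total_steps / n_cycles
    position = (step % cycle_len) / cycle_len  # in [0, 1) within the cycle
    return max_beta * min(1.0, position / proportion)


def thresholded_kl(kl: float, target_kl: float = 2.0) -> float:
    """Clamp the KL term from below so it is not pushed under target_kl."""
    return max(kl, target_kl)


def vae_loss(reconstruction: float, kl: float, beta: float,
             target_kl: float = 2.0) -> float:
    """Reconstruction loss plus the beta-weighted, thresholded KL term."""
    return reconstruction + beta * thresholded_kl(kl, target_kl)
```

In this sketch, with 4,000 total steps and n_cycles=40, each cycle lasts 100 steps: beta is 0 at step 0, 0.5 at step 25, and equal to max_beta from step 50 until the cycle restarts at step 100. Restarting beta at 0 lets the decoder periodically focus on reconstruction, and the threshold removes the incentive to shrink the KL term below target_kl, both of which are common remedies for posterior collapse in text VAEs.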

License

LangVAE is licensed under the GPLv3 License. See the LICENSE file for details.
