LangVAE: Large Language VAEs made simple

LangVAE is a Python library for training and running Variational Autoencoders (VAEs) over text, built on top of pre-trained language-model encoders and decoders. It provides an easy-to-use interface for training VAEs on text data and lets users customize the model architecture, loss function, and training parameters.

Installation

To install LangVAE, simply run:

pip install langvae

This will install all necessary dependencies and set up the package for use in your Python projects.
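
To confirm that the installation succeeded, one option is to query the installed version from Python. The snippet below is a minimal sketch that uses only the standard library; the helper name `installed_version` is hypothetical and not part of LangVAE.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Prints a version string if LangVAE is installed, otherwise None.
print(installed_version("langvae"))
```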

Usage

Here's a basic example of how to train a VAE on text data using LangVAE:

from pythae.models.vae import VAEConfig
from saf_datasets import EntailmentBankDataSet
from langvae import LangVAE
from langvae.encoders import SentenceEncoder
from langvae.decoders import SentenceDecoder
from langvae.data_conversion.tokenization import TokenizedDataSet
from langvae.pipelines import LanguageTrainingPipeline
from langvae.trainers import CyclicalScheduleKLThresholdTrainerConfig
from langvae.trainers.training_callbacks import TensorBoardCallback

DEVICE = "cuda"
LATENT_SIZE = 128
MAX_SENT_LEN = 32

# Load pre-trained sentence encoder and decoder models.
decoder = SentenceDecoder("gpt2", LATENT_SIZE, MAX_SENT_LEN, device=DEVICE, device_map="auto")
encoder = SentenceEncoder("bert-base-cased", LATENT_SIZE, decoder.tokenizer, caching=True, device=DEVICE)

# Select explanatory sentences from the EntailmentBank dataset.
dataset = [
    sent for sent in EntailmentBankDataSet()
    if (sent.annotations["type"] == "answer" or
        sent.annotations["type"].startswith("context"))
]

# Set training and evaluation datasets with auto tokenization.
eval_size = int(0.1 * len(dataset))
train_dataset = TokenizedDataSet(sorted(dataset[:-eval_size], key=lambda x: len(x.surface), reverse=True),
                                 decoder.tokenizer, decoder.max_len, caching=True,
                                 cache_persistence="eb_train_tok-gpt2_cache.jsonl")
eval_dataset = TokenizedDataSet(sorted(dataset[-eval_size:], key=lambda x: len(x.surface), reverse=True),
                                decoder.tokenizer, decoder.max_len, caching=True,
                                cache_persistence="eb_eval_tok-gpt2_cache.jsonl")


# Define VAE model configuration
model_config = VAEConfig(latent_dim=LATENT_SIZE)

# Initialize LangVAE model
model = LangVAE(model_config, encoder, decoder)

exp_label = f"eb-langvae-bert-gpt2-{LATENT_SIZE}"

# Train VAE on explanatory sentences
training_config = CyclicalScheduleKLThresholdTrainerConfig(
    output_dir=exp_label,
    num_epochs=20,
    learning_rate=1e-3,
    per_device_train_batch_size=50,
    per_device_eval_batch_size=50,
    steps_saving=5,
    optimizer_cls="AdamW",
    scheduler_cls="ReduceLROnPlateau",
    scheduler_params={"patience": 5, "factor": 0.5},
    max_beta=1.0,
    n_cycles=16,  # num_epochs * 0.8
    target_kl=2.0,
    keep_best_on_train=True
)

pipeline = LanguageTrainingPipeline(
    training_config=training_config,
    model=model
)

# Monitor the training progress with `tensorboard --logdir=runs &`
tb_callback = TensorBoardCallback(exp_label)

pipeline(
    train_data=train_dataset,
    eval_data=eval_dataset,
    callbacks=[tb_callback]
)

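This example loads pre-trained encoder and decoder models (BERT and GPT-2), selects sentences from the EntailmentBank dataset and holds out the last 10% for evaluation, defines a VAE model configuration, initializes the LangVAE model, and trains it with LanguageTrainingPipeline while logging progress to TensorBoard.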
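
The `max_beta`, `n_cycles`, and `target_kl` settings in the training configuration suggest a cyclical KL-weight schedule combined with a KL threshold. The sketch below illustrates the general idea in plain Python. It is not LangVAE's implementation: the function names, the `ramp_fraction` parameter, and the per-dimension reading of `target_kl` are all assumptions made for illustration.

```python
from typing import Sequence


def cyclical_beta(step: int, total_steps: int, n_cycles: int = 4,
                  max_beta: float = 1.0, ramp_fraction: float = 0.5) -> float:
    """KL weight at `step` under a linear cyclical annealing schedule.

    Each cycle ramps beta linearly from 0 to `max_beta` over the first
    `ramp_fraction` of the cycle, then holds it at `max_beta`.
    (Hypothetical helper; the real trainer may differ.)
    """
    if total_steps <= 0 or n_cycles <= 0 or not 0.0 < ramp_fraction <= 1.0:
        raise ValueError("total_steps and n_cycles must be positive, "
                         "ramp_fraction must be in (0, 1]")
    cycle_len = total_steps / n_cycles
    position = (step % cycle_len) / cycle_len  # progress within the cycle, in [0, 1)
    return max_beta * min(1.0, position / ramp_fraction)


def thresholded_kl(kl_per_dim: Sequence[float], target_kl: float) -> float:
    """Sum per-dimension KL terms, flooring each one at `target_kl`.

    Dimensions whose KL is already below the threshold contribute a constant,
    so they receive no gradient pushing them further toward the prior.
    (Assumed interpretation of a KL threshold, labeled as such.)
    """
    return sum(max(kl, target_kl) for kl in kl_per_dim)


if __name__ == "__main__":
    # Print the KL weight at a few points of a 100-step, 4-cycle schedule.
    for step in (0, 6, 12, 18, 24, 25):
        print(step, cyclical_beta(step, total_steps=100, n_cycles=4))
```

Restarting the ramp several times during training is a common way to reduce posterior collapse in text VAEs, since the decoder repeatedly gets phases where the latent code is cheap to use.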

License

LangVAE is licensed under the GNU General Public License v3 (GPLv3). See the LICENSE file for details.
