LangVAE: Large Language VAEs made simple
Project description
LangVAE is a Python library for training and running language models using Variational Autoencoders (VAEs). It provides an easy-to-use interface to train VAEs on text data, allowing users to customize the model architecture, loss function, and training parameters.
Installation
To install LangVAE, simply run:
```
pip install langvae
```
This will install all necessary dependencies and set up the package for use in your Python projects.
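A quick smoke test to confirm the package is importable (printing `langvae.__file__` shows where it was installed):

```python
import langvae  # raises ImportError if the installation failed

print(langvae.__file__)  # path of the installed package
```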
Usage
Here's a basic example of how to train a VAE on text data using LangVAE:
```python
from langvae import LangVAE
from langvae.encoders import SentenceEncoder
from langvae.decoders import SentenceDecoder
from langvae.data_conversion.tokenization import TokenizedDataSet
from langvae.pipelines import LanguageTrainingPipeline
from langvae.trainers import CyclicalScheduleKLThresholdTrainerConfig
from pythae.models.vae import VAEConfig  # VAEConfig comes from pythae, which LangVAE builds on
from saf_datasets import EntailmentBankDataSet

DEVICE = "cuda"

# Load pre-trained sentence encoder and decoder models.
decoder = SentenceDecoder("gpt2", latent_size=32, max_len=32, device=DEVICE)
# latent_size=32; the encoder reuses the decoder's tokenizer.
encoder = SentenceEncoder("bert-base-cased", 32, decoder.tokenizer, device=DEVICE)

# Select explanatory sentences from the EntailmentBank dataset.
dataset = [
    sent for sent in EntailmentBankDataSet()
    if (sent.annotations["type"] == "answer" or
        sent.annotations["type"].startswith("context"))
]

# Set training and evaluation datasets with auto tokenization.
eval_size = int(0.1 * len(dataset))
train_dataset = TokenizedDataSet(dataset[:-eval_size], decoder.tokenizer, decoder.max_len)
eval_dataset = TokenizedDataSet(dataset[-eval_size:], decoder.tokenizer, decoder.max_len)

# Define the VAE model configuration.
model_config = VAEConfig(
    input_dim=(train_dataset[0]["data"].shape[-2], train_dataset[0]["data"].shape[-1]),
    latent_dim=32
)

# Initialize the LangVAE model.
model = LangVAE(model_config, encoder, decoder)

# Train the VAE on the explanatory sentences.
training_config = CyclicalScheduleKLThresholdTrainerConfig(
    output_dir='expl_vae',
    num_epochs=5,
    learning_rate=1e-4,
    per_device_train_batch_size=50,
    per_device_eval_batch_size=50,
    steps_saving=1,
    optimizer_cls="AdamW",
    scheduler_cls="ReduceLROnPlateau",
    scheduler_params={"patience": 5, "factor": 0.5},
    max_beta=1.0,
    n_cycles=40,
    target_kl=2.0
)

pipeline = LanguageTrainingPipeline(
    training_config=training_config,
    model=model
)

pipeline(
    train_data=train_dataset,
    eval_data=eval_dataset
)
```
This example loads pre-trained encoder and decoder models, defines the VAE model configuration, initializes the LangVAE model, and trains it on explanatory sentences from EntailmentBank using a training pipeline with a cyclical KL annealing schedule.
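The `max_beta`, `n_cycles`, and `target_kl` parameters of `CyclicalScheduleKLThresholdTrainerConfig` control how the KL term is weighted during training: cyclical schedules (Fu et al., 2019) repeatedly ramp the KL weight β from 0 up to `max_beta` to mitigate posterior collapse, and a KL threshold such as `target_kl` is typically a floor below which the KL term is not penalized (free-bits style). The sketch below illustrates the general shape of such a schedule; it is an independent illustration, not LangVAE's internal implementation, and the function name `cyclical_beta` is ours:

```python
def cyclical_beta(step: int, total_steps: int, n_cycles: int = 40,
                  max_beta: float = 1.0, ramp: float = 0.5) -> float:
    """Cyclical KL-weight schedule: within each cycle, beta ramps linearly
    from 0 to max_beta over the first `ramp` fraction of the cycle, then
    holds at max_beta for the remainder."""
    cycle_len = total_steps / n_cycles      # steps per cycle
    pos = (step % cycle_len) / cycle_len    # position within cycle, in [0, 1)
    return max_beta * min(pos / ramp, 1.0)

# The KL weight rises from 0 to 1 and resets 40 times over training:
betas = [cyclical_beta(s, total_steps=4000) for s in range(4000)]
assert betas[0] == 0.0 and max(betas) == 1.0
```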
License
LangVAE is licensed under the GPLv3 License. See the LICENSE file for details.
File details
Details for the file langvae-0.2.3.tar.gz.
File metadata
- Download URL: langvae-0.2.3.tar.gz
- Upload date:
- Size: 24.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3948d71965a15131873e63110beb323dfb8a5c129bbcdc66b015ee5853487a36
MD5 | e945bdf555c4489ba77cd98c1e836a75
BLAKE2b-256 | 93a3801edf4c112f014a001606bf2326f1540ead2eca2fdcfb4cbaa864408373
File details
Details for the file langvae-0.2.3-py3-none-any.whl.
File metadata
- Download URL: langvae-0.2.3-py3-none-any.whl
- Upload date:
- Size: 28.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 184c2ceb90b4b683cb6d93eafcd56d04712b1e466413e2744ade228d53f36ddf
MD5 | f682e165a0b2523249b171ce7e9ea4c6
BLAKE2b-256 | 54e21608d86f7738eab260969bd636ac8b0684b1056e17dbd5450781169baf1c
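To check a downloaded file against the digests above, you can compute its SHA256 locally; a minimal sketch, assuming the sdist archive sits in the current directory:

```python
import hashlib

# Published SHA256 digest for langvae-0.2.3.tar.gz (from the table above).
EXPECTED = "3948d71965a15131873e63110beb323dfb8a5c129bbcdc66b015ee5853487a36"

with open("langvae-0.2.3.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

assert digest == EXPECTED, f"checksum mismatch: {digest}"
```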