Vector-quantized time-series generation with a bidirectional prior model.
TimeVQVAE
This is the official GitHub repository for the PyTorch implementation of TimeVQVAE from our paper "Vector Quantized Time Series Generation with a Bidirectional Prior Model" (AISTATS 2023).
TimeVQVAE is a robust time series generation model that uses vector quantization to compress data into a discrete latent space (stage 1) and a bidirectional transformer for prior learning (stage 2).
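Stage 1's vector quantization maps each continuous latent vector to the index of its nearest codebook entry, and the stage-2 prior is then learned over those discrete indices. A minimal toy sketch of the nearest-codebook lookup in plain Python (an illustration of the general technique, not the library's implementation):

```python
def quantize(vectors, codebook):
    """Return, for each vector, the index of its closest codebook entry (L2 distance)."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return [
        min(range(len(codebook)), key=lambda k: sq_dist(v, codebook[k]))
        for v in vectors
    ]

# Three 2-D latent vectors quantized against a 3-entry codebook.
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 0.5]]
latents = [[0.9, 1.2], [-0.8, 0.4], [0.1, -0.1]]
print(quantize(latents, codebook))  # [1, 2, 0]
```

The returned integer indices are the "tokens" that the stage-2 prior model is trained on.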
Installation
Install from PyPI with uv:

```shell
uv add timevqvae
```
Usage
Stage 1
Example of running VQVAE with a dummy 1D time-series input (batch, channels, length):
```python
import torch

from timevqvae.vqvae import VQVAE

vqvae = VQVAE(
    in_channels=1,
    input_length=128,
    n_fft=4,
    init_dim=4,
    hid_dim=128,
    downsampled_width_l=8,
    downsampled_width_h=32,
    encoder_n_resnet_blocks=2,
    decoder_n_resnet_blocks=2,
    codebook_size_l=1024,
    codebook_size_h=1024,
    kmeans_init=True,
    codebook_dim=8,
)

x = torch.randn(4, 1, 128)  # (batch, channels, length)
out = vqvae(x)

print(out.x_recon.shape)        # (4, 1, 128)
print(out.recons_loss.keys())   # dict_keys(['LF.time', 'HF.time'])
print(out.vq_losses.keys())     # dict_keys(['LF', 'HF'])
print(out.perplexities.keys())  # dict_keys(['LF', 'HF'])
```
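The `perplexities` output above reports codebook usage for the LF and HF codebooks. Assuming these are standard codebook perplexities (the exponential of the entropy of the empirical code-usage distribution — an interpretation common to VQ models, not something this README states), a toy computation looks like:

```python
import math

def codebook_perplexity(token_ids, codebook_size):
    """exp(entropy) of the empirical code-usage distribution.
    Ranges from 1 (one code dominates) up to codebook_size (uniform usage)."""
    counts = [0] * codebook_size
    for t in token_ids:
        counts[t] += 1
    n = len(token_ids)
    probs = [c / n for c in counts if c > 0]
    entropy = -sum(p * math.log(p) for p in probs)
    return math.exp(entropy)

print(codebook_perplexity([0, 1, 2, 3], 1024))  # ~4.0: four codes used equally
print(codebook_perplexity([7, 7, 7, 7], 1024))  # 1.0: a single code used
```

A healthy stage-1 model keeps perplexity well above 1, meaning many codebook entries are actively used.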
Stage 2
Example of running MaskGIT for:

- training loss computation (internally calls `compute_mask_prediction_loss`)
- token sampling with `iterative_decoding`, and decoding the sampled tokens back to time series
```python
import torch

from timevqvae.maskgit import MaskGIT, PriorModelConfig

device = "cuda" if torch.cuda.is_available() else "cpu"

# Reuse `vqvae` from the Stage-1 example above.
# For stage 2, this should be a pretrained stage-1 model with trained weights loaded.
# Example:
# vqvae.load_state_dict(torch.load("vqvae_stage1.pt", map_location=device))
# vqvae = vqvae.to(device).eval()

# Stage-2 prior model.
# n_classes is dataset-dependent (example: 2 classes).
maskgit = MaskGIT(
    vqvae=vqvae,
    lf_choice_temperature=10.0,
    hf_choice_temperature=0.0,
    lf_num_sampling_steps=10,
    hf_num_sampling_steps=10,
    lf_codebook_size=1024,
    hf_codebook_size=1024,
    transformer_embedding_dim=128,
    lf_prior_model_config=PriorModelConfig(
        hidden_dim=128,
        n_layers=4,
        heads=2,
        ff_mult=1,
        use_rmsnorm=True,
        p_unconditional=0.2,
        model_dropout=0.3,
        emb_dropout=0.3,
    ),
    hf_prior_model_config=PriorModelConfig(
        hidden_dim=32,
        n_layers=1,
        heads=1,
        ff_mult=1,
        use_rmsnorm=True,
        p_unconditional=0.2,
        model_dropout=0.3,
        emb_dropout=0.3,
    ),
    classifier_free_guidance_scale=1.0,
    n_classes=2,
).to(device)

# ---------------------------------------------------------
# 1) Training logic example: dataclass loss from compute_mask_prediction_loss
# ---------------------------------------------------------
maskgit.train()

x = torch.randn(4, 1, 128, device=device)  # (batch, channels, length)
class_condition = torch.randint(0, 2, (4, 1), device=device)  # (batch, 1)

# maskgit.forward(...) internally calls training_logic.compute_mask_prediction_loss(...)
losses = maskgit(x, class_condition)
print(
    losses.total_mask_prediction_loss.item(),
    losses.mask_pred_loss_l.item(),
    losses.mask_pred_loss_h.item(),
)

# ---------------------------------------------------------
# 2) Sampling logic example: iterative_decoding + token decoding
# ---------------------------------------------------------
maskgit.eval()

with torch.no_grad():
    token_ids_l, token_ids_h = maskgit.iterative_decoding(
        num_samples=4,
        mode="cosine",
        class_condition=1,  # int or tensor; normalized to (num_samples, 1)
        device=device,
    )

x_l = maskgit.decode_token_ind_to_timeseries(token_ids_l, frequency="lf")  # (4, 1, 128)
x_h = maskgit.decode_token_ind_to_timeseries(token_ids_h, frequency="hf")  # (4, 1, 128)
x_gen = x_l + x_h

print(token_ids_l.shape, token_ids_h.shape)
print(x_l.shape, x_h.shape, x_gen.shape)
```
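The `mode="cosine"` argument refers to a MaskGIT-style unmasking schedule: at decoding step t of T, roughly a fraction cos(π/2 · t/T) of the tokens remains masked, so early steps commit only a few confident tokens and later steps commit many. A toy sketch of such a schedule (an illustration of the general technique, not the library's exact implementation):

```python
import math

def num_masked(step, total_steps, num_tokens):
    """Number of tokens still masked after `step` of `total_steps` decoding steps,
    under the cosine schedule cos(pi/2 * step / total_steps)."""
    frac = math.cos(math.pi / 2 * step / total_steps)
    return math.floor(frac * num_tokens)

# 10 decoding steps over a 32-token latent sequence.
schedule = [num_masked(t, 10, 32) for t in range(1, 11)]
print(schedule)  # [31, 30, 28, 25, 22, 18, 14, 9, 5, 0]
```

The masked count decays slowly at first and reaches 0 at the final step, at which point every token has been committed.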
Google Colab
A Google Colab notebook is available for time series generation with the pretrained VQVAE. (NB: make sure to change your notebook runtime to GPU.) The usage is simple:

- User Settings: specify `dataset_name` and `n_samples_to_generate`.
- Sampling: run the unconditional sampling and the class-conditional sampling.
Related Papers
Neural Mapper for Vector Quantized Time Series Generator (NM-VQTSG)
If you want to improve the realism of generated time series while preserving context, please see our Neural Mapper paper:
- Paper: https://arxiv.org/abs/2501.17553
- Citation entry: [3] below
TimeVQVAE for Anomaly Detection (TimeVQVAE-AD)
If your focus is anomaly detection with explainability and counterfactual sampling, please see TimeVQVAE-AD:
- Paper: https://www.sciencedirect.com/science/article/pii/S0031320324008216
- Code: https://github.com/ML4ITS/TimeVQVAE-AnomalyDetection
- Citation entry: [4] below
References
[1] Lee, Daesoo, Sara Malacarne, and Erlend Aune. "Vector Quantized Time Series Generation with a Bidirectional Prior Model." International Conference on Artificial Intelligence and Statistics. PMLR, 2023.
[3] Lee, Daesoo, Sara Malacarne, and Erlend Aune. "Closing the Gap Between Synthetic and Ground Truth Time Series Distributions via Neural Mapping." arXiv preprint arXiv:2501.17553 (2025).
[4] Lee, Daesoo, Sara Malacarne, and Erlend Aune. "Explainable time series anomaly detection using masked latent generative modeling." Pattern Recognition (2024): 110826.