
Vector-quantized time-series generation with a bidirectional prior model.


TimeVQVAE

This is the official GitHub repository for the PyTorch implementation of TimeVQVAE from our paper "Vector Quantized Time Series Generation with a Bidirectional Prior Model" (AISTATS 2023).

TimeVQVAE is a robust time series generation model with two stages: stage 1 uses vector quantization to compress time series into a discrete latent space, and stage 2 learns a prior over the resulting discrete tokens with a bidirectional transformer.
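
Stage 1 turns continuous latent vectors into discrete tokens through vector quantization. The minimal sketch below, written in plain Python rather than PyTorch, shows the generic VQ-VAE lookup: each latent vector is replaced by its nearest codebook entry under squared Euclidean distance, and the index of that entry becomes the token that the stage-2 prior learns to model. The function `quantize` and the toy `codebook` are hypothetical illustrations, not part of the timevqvae API; the library's own quantizer (for example with `kmeans_init=True` and a projected `codebook_dim`) may differ in its details.

```python
# Generic VQ-VAE nearest-neighbour lookup (illustrative sketch; NOT timevqvae's implementation).


def quantize(z, codebook):
    """Map each latent vector in z to the index of its nearest codebook vector (squared L2)."""
    indices = []
    for vec in z:
        # Squared Euclidean distance from this latent vector to every codebook vector
        dists = [sum((a - b) ** 2 for a, b in zip(vec, code)) for code in codebook]
        # The discrete token is the index of the closest codebook vector
        indices.append(min(range(len(codebook)), key=dists.__getitem__))
    # The quantized latent replaces each vector with its selected codebook vector
    z_q = [codebook[i] for i in indices]
    return indices, z_q


# Toy example: a hypothetical 3-entry codebook of 2-D vectors
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
z = [[0.1, -0.2], [0.9, 1.2], [-0.8, 1.7]]
indices, z_q = quantize(z, codebook)
print(indices)  # [0, 1, 2]
```

In a trained model the token indices, not the vectors, are what the bidirectional transformer in stage 2 predicts.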

Installation

Install from PyPI with uv:

uv add timevqvae

Usage

Stage 1

Example of running the VQVAE on a dummy univariate time-series input of shape (batch, channels, length):

import torch
from timevqvae.vqvae import VQVAE

vqvae = VQVAE(
    in_channels=1,
    input_length=128,
    n_fft=4,
    init_dim=4,
    hid_dim=128,
    downsampled_width_l=8,
    downsampled_width_h=32,
    encoder_n_resnet_blocks=2,
    decoder_n_resnet_blocks=2,
    codebook_size_l=1024,
    codebook_size_h=1024,
    kmeans_init=True,
    codebook_dim=8,
)

x = torch.randn(4, 1, 128)  # (batch, channels, length)
out = vqvae(x)

print(out.x_recon.shape)          # (4, 1, 128)
print(out.recons_loss.keys())     # dict_keys(['LF.time', 'HF.time'])
print(out.vq_losses.keys())       # dict_keys(['LF', 'HF'])
print(out.perplexities.keys())    # dict_keys(['LF', 'HF'])
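
The `perplexities` entries above summarize how many codebook entries are actually in use. A common definition in VQ-VAE code, assumed here rather than taken from timevqvae's source, is the exponential of the entropy of the empirical code-usage distribution: it equals 1 when every latent vector maps to the same code and approaches the codebook size when usage is uniform. The helper `codebook_perplexity` below is hypothetical.

```python
import math
from collections import Counter


def codebook_perplexity(indices, codebook_size):
    """exp(entropy) of the empirical code-usage distribution (assumed definition).

    Expects indices in [0, codebook_size). Ranges from 1 (all vectors use one code)
    up to codebook_size (all codes used equally often).
    """
    counts = Counter(indices)
    n = len(indices)
    entropy = 0.0
    for k in range(codebook_size):
        p = counts.get(k, 0) / n
        if p > 0:  # unused codes contribute nothing to the entropy
            entropy -= p * math.log(p)
    return math.exp(entropy)


# Toy token indices, e.g. from a quantizer with an 8-entry codebook
print(codebook_perplexity([3, 3, 3, 3], codebook_size=8))  # 1.0: codebook collapse
print(codebook_perplexity([0, 1, 2, 3], codebook_size=8))  # ~4.0: four codes used evenly
```

A perplexity far below `codebook_size_l` or `codebook_size_h` suggests that most codebook entries are unused.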

Stage 2

Example of running MaskGIT for:

  1. training loss computation (internally calls compute_mask_prediction_loss), and
  2. token sampling with iterative_decoding, followed by decoding the sampled tokens back to time series.

import torch
from timevqvae.maskgit import MaskGIT, PriorModelConfig

device = "cuda" if torch.cuda.is_available() else "cpu"

# Reuse `vqvae` from the Stage-1 example above.
# For Stage-2, this should be a pretrained Stage-1 model with trained weights loaded.
# Example:
# vqvae.load_state_dict(torch.load("vqvae_stage1.pt", map_location=device))
# vqvae = vqvae.to(device).eval()

# Stage-2 prior model.
# n_classes is dataset-dependent (example: 2 classes).
maskgit = MaskGIT(
    vqvae=vqvae,
    lf_choice_temperature=10.0,
    hf_choice_temperature=0.0,
    lf_num_sampling_steps=10,
    hf_num_sampling_steps=10,
    lf_codebook_size=1024,
    hf_codebook_size=1024,
    transformer_embedding_dim=128,
    lf_prior_model_config=PriorModelConfig(
        hidden_dim=128,
        n_layers=4,
        heads=2,
        ff_mult=1,
        use_rmsnorm=True,
        p_unconditional=0.2,
        model_dropout=0.3,
        emb_dropout=0.3,
    ),
    hf_prior_model_config=PriorModelConfig(
        hidden_dim=32,
        n_layers=1,
        heads=1,
        ff_mult=1,
        use_rmsnorm=True,
        p_unconditional=0.2,
        model_dropout=0.3,
        emb_dropout=0.3,
    ),
    classifier_free_guidance_scale=1.0,
    n_classes=2,
).to(device)

# ---------------------------------------------------------
# 1) Training logic example: dataclass loss from compute_mask_prediction_loss
# ---------------------------------------------------------
maskgit.train()
x = torch.randn(4, 1, 128, device=device)                      # (batch, channels, length)
class_condition = torch.randint(0, 2, (4, 1), device=device)   # (batch, 1)

# maskgit.forward(...) internally calls training_logic.compute_mask_prediction_loss(...)
losses = maskgit(x, class_condition)
print(
    losses.total_mask_prediction_loss.item(),
    losses.mask_pred_loss_l.item(),
    losses.mask_pred_loss_h.item(),
)

# ---------------------------------------------------------
# 2) Sampling logic example: iterative_decoding + token decoding
# ---------------------------------------------------------
maskgit.eval()
with torch.no_grad():
    token_ids_l, token_ids_h = maskgit.iterative_decoding(
        num_samples=4,
        mode="cosine",
        class_condition=1,   # int or tensor; normalized to (num_samples, 1)
        device=device,
    )

    x_l = maskgit.decode_token_ind_to_timeseries(token_ids_l, frequency="lf")  # (4, 1, 128)
    x_h = maskgit.decode_token_ind_to_timeseries(token_ids_h, frequency="hf")  # (4, 1, 128)
    x_gen = x_l + x_h

print(token_ids_l.shape, token_ids_h.shape)
print(x_l.shape, x_h.shape, x_gen.shape)
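
The `mode="cosine"` argument to `iterative_decoding` refers to a masking schedule. In MaskGIT-style decoding, all tokens start masked; at each step the transformer predicts every masked token, the most confident predictions are kept, and the rest are re-masked. The sketch below shows only the schedule, assuming the cosine formulation from the original MaskGIT paper (Chang et al., 2022), where the fraction still masked after step t of T is cos(pi/2 * (t+1)/T). Whether timevqvae rounds the same way is an assumption, and `cosine_mask_schedule` and `tokens_revealed_per_step` are hypothetical helpers, not library functions.

```python
import math


def cosine_mask_schedule(num_tokens, num_steps):
    """Tokens left masked after each decoding step (MaskGIT-style cosine schedule, assumed)."""
    schedule = []
    for t in range(num_steps):
        # Mask ratio decreases from close to 1 towards 0 as t approaches num_steps - 1
        ratio = math.cos(math.pi / 2 * (t + 1) / num_steps)
        schedule.append(math.floor(ratio * num_tokens))
    return schedule


def tokens_revealed_per_step(num_tokens, num_steps):
    """How many tokens get fixed (unmasked) at each step under the schedule above."""
    remaining = [num_tokens] + cosine_mask_schedule(num_tokens, num_steps)
    return [before - after for before, after in zip(remaining, remaining[1:])]


print(cosine_mask_schedule(16, 4))      # [14, 11, 6, 0]
print(tokens_revealed_per_step(16, 4))  # [2, 3, 5, 6]
```

Revealing few tokens early and more later lets the transformer commit to its most confident tokens first and condition later predictions on them. In the Stage 2 example, `lf_num_sampling_steps` and `hf_num_sampling_steps` presumably play the role of `num_steps`.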

Google Colab

NB: make sure to set your notebook runtime type to GPU.

A Google Colab notebook is available for generating time series with the pretrained VQVAE. Usage is simple:

  1. User Settings: specify dataset_name and n_samples_to_generate.
  2. Sampling: run unconditional and class-conditional sampling.

Related Papers

Neural Mapper for Vector Quantized Time Series Generator (NM-VQTSG)

To improve the realism of generated time series while preserving their context, see our Neural Mapper paper [3].

TimeVQVAE for Anomaly Detection (TimeVQVAE-AD)

If your focus is anomaly detection with explainability and counterfactual sampling, see TimeVQVAE-AD [4].

References

[1] Lee, Daesoo, Sara Malacarne, and Erlend Aune. "Vector Quantized Time Series Generation with a Bidirectional Prior Model." International Conference on Artificial Intelligence and Statistics. PMLR, 2023.

[3] Lee, Daesoo, Sara Malacarne, and Erlend Aune. "Closing the Gap Between Synthetic and Ground Truth Time Series Distributions via Neural Mapping." arXiv preprint arXiv:2501.17553 (2025).

[4] Lee, Daesoo, Sara Malacarne, and Erlend Aune. "Explainable time series anomaly detection using masked latent generative modeling." Pattern Recognition (2024): 110826.



Download files


Source Distribution

timevqvae-1.0.1.tar.gz (25.8 kB)

Uploaded Source

Built Distribution


timevqvae-1.0.1-py3-none-any.whl (26.1 kB)

Uploaded Python 3

File details

Details for the file timevqvae-1.0.1.tar.gz.

File metadata

  • Download URL: timevqvae-1.0.1.tar.gz
  • Size: 25.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.10.2 (uv publish, macOS)

File hashes

Hashes for timevqvae-1.0.1.tar.gz
Algorithm Hash digest
SHA256 afc6f15ea8bcf71f4efd8caa3e00c05a9f0e4c78f11610dd98fccc48cdfe266e
MD5 54060b84058cadd60f0d97d1f40f850d
BLAKE2b-256 159e1cd4e7419f93906095aa4c46861bc8d622d95a1a10bfff3c9b2b8bfa449c


File details

Details for the file timevqvae-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: timevqvae-1.0.1-py3-none-any.whl
  • Size: 26.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.10.2 (uv publish, macOS)

File hashes

Hashes for timevqvae-1.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 bdb1e1fbee59d028dc2cab46c5d0df14155935ca2eaf0ff46ec1e025bdaaba55
MD5 d5a05c6d30f2e5e47777f2874b829a25
BLAKE2b-256 49d048c4a36c9d1370643a97ca51ab7bafd472c73d8953a870a1dec14d07af70

