Vector-quantized time-series generation with a bidirectional prior model.

TimeVQVAE

This is the official GitHub repository for the PyTorch implementation of TimeVQVAE from our paper "Vector Quantized Time Series Generation with a Bidirectional Prior Model" (AISTATS 2023).

TimeVQVAE is a robust time series generation model that compresses time series into a discrete latent space via vector quantization (stage 1) and learns a prior over the resulting discrete tokens with a bidirectional transformer (stage 2).

Installation

Install from PyPI with uv:

uv add timevqvae

Usage

Stage 1

Example of running VQVAE with a dummy 1D time-series input (batch, channels, length):

import torch
from timevqvae.vqvae import VQVAE

model = VQVAE(
    in_channels=1,
    input_length=128,
    n_fft=4,
    init_dim=4,
    hid_dim=128,
    downsampled_width_l=8,
    downsampled_width_h=32,
    encoder_n_resnet_blocks=2,
    decoder_n_resnet_blocks=2,
    codebook_size_l=1024,
    codebook_size_h=1024,
    kmeans_init=True,
    codebook_dim=8,
)

x = torch.randn(4, 1, 128)  # (batch, channels, length)
out = model(x)

print(out.x_recon.shape)          # (4, 1, 128)
print(out.recons_loss.keys())     # dict_keys(['LF.time', 'HF.time'])
print(out.vq_losses.keys())       # dict_keys(['LF', 'HF'])
print(out.perplexities.keys())    # dict_keys(['LF', 'HF'])
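The quantization inside stage 1 follows the standard VQ-VAE recipe: each encoder output vector is snapped to its nearest codebook entry, and the index of that entry becomes a discrete token. The following is an illustrative NumPy sketch of that lookup (not the library's actual implementation), with sizes matching codebook_size_*=1024 and codebook_dim=8 from the example above:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy codebook: 1024 codes of dimension 8.
codebook = rng.normal(size=(1024, 8))

# Toy encoder output: 4 sequences of 32 latent vectors, dimension 8.
z_e = rng.normal(size=(4, 32, 8))

# Nearest-neighbor lookup: for each latent vector, find the closest code
# by Euclidean distance, then replace the vector with that code.
dists = np.linalg.norm(z_e[..., None, :] - codebook, axis=-1)  # (4, 32, 1024)
token_ids = dists.argmin(axis=-1)                              # (4, 32)
z_q = codebook[token_ids]                                      # (4, 32, 8)

print(token_ids.shape, z_q.shape)  # (4, 32) (4, 32, 8)
```

The token_ids grid is exactly what stage 2 models; the decoder only ever sees the quantized vectors z_q.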

Stage 2

Example of running MaskGIT for:

  1. training loss computation (which internally calls compute_mask_prediction_loss), and
  2. token sampling with iterative_decoding, followed by decoding the sampled tokens back to time series.

import torch
from timevqvae.maskgit import MaskGIT, PriorModelConfig

device = "cuda" if torch.cuda.is_available() else "cpu"

# Reuse `vqvae` from the Stage-1 example above.
# For Stage-2, this should be a pretrained Stage-1 model with trained weights loaded.
# Example:
# vqvae.load_state_dict(torch.load("vqvae_stage1.pt", map_location=device))
# vqvae = vqvae.to(device).eval()

# Stage-2 prior model.
# n_classes is dataset-dependent (example: 2 classes).
maskgit = MaskGIT(
    vqvae=vqvae,
    lf_choice_temperature=10.0,
    hf_choice_temperature=0.0,
    lf_num_sampling_steps=10,
    hf_num_sampling_steps=10,
    lf_codebook_size=1024,
    hf_codebook_size=1024,
    transformer_embedding_dim=128,
    lf_prior_model_config=PriorModelConfig(
        hidden_dim=128,
        n_layers=4,
        heads=2,
        ff_mult=1,
        use_rmsnorm=True,
        p_unconditional=0.2,
        model_dropout=0.3,
        emb_dropout=0.3,
    ),
    hf_prior_model_config=PriorModelConfig(
        hidden_dim=32,
        n_layers=1,
        heads=1,
        ff_mult=1,
        use_rmsnorm=True,
        p_unconditional=0.2,
        model_dropout=0.3,
        emb_dropout=0.3,
    ),
    classifier_free_guidance_scale=1.0,
    n_classes=2,
).to(device)

# ---------------------------------------------------------
# 1) Training logic example: dataclass loss from compute_mask_prediction_loss
# ---------------------------------------------------------
maskgit.train()
x = torch.randn(4, 1, 128, device=device)                      # (batch, channels, length)
class_condition = torch.randint(0, 2, (4, 1), device=device)   # (batch, 1)

# maskgit.forward(...) internally calls training_logic.compute_mask_prediction_loss(...)
losses = maskgit(x, class_condition)
print(
    losses.total_mask_prediction_loss.item(),
    losses.mask_pred_loss_l.item(),
    losses.mask_pred_loss_h.item(),
)
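The configs above set p_unconditional (the probability of dropping the class condition during training) and classifier_free_guidance_scale. At sampling time, classifier-free guidance typically blends conditional and unconditional predictions; the sketch below shows the standard formulation, not necessarily this library's exact code:

```python
import numpy as np

def cfg_logits(cond_logits, uncond_logits, guidance_scale):
    """Standard classifier-free guidance: push the logits away from the
    unconditional prediction, toward the conditional one."""
    return uncond_logits + guidance_scale * (cond_logits - uncond_logits)

cond = np.array([2.0, 0.5, -1.0])
uncond = np.array([1.0, 1.0, 0.0])

# guidance_scale = 1.0 recovers the purely conditional logits,
# guidance_scale = 0.0 the purely unconditional ones, and
# guidance_scale > 1.0 amplifies the class-conditional signal.
print(cfg_logits(cond, uncond, 1.0))
print(cfg_logits(cond, uncond, 2.0))
```

Training with p_unconditional > 0 is what makes the unconditional prediction available at sampling time from the same network.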

# ---------------------------------------------------------
# 2) Sampling logic example: iterative_decoding + token decoding
# ---------------------------------------------------------
maskgit.eval()
with torch.no_grad():
    token_ids_l, token_ids_h = maskgit.iterative_decoding(
        num_samples=4,
        mode="cosine",
        class_condition=1,   # int or tensor; normalized to (num_samples, 1)
        device=device,
    )

    x_l = maskgit.decode_token_ind_to_timeseries(token_ids_l, frequency="lf")  # (4, 1, 128)
    x_h = maskgit.decode_token_ind_to_timeseries(token_ids_h, frequency="hf")  # (4, 1, 128)
    x_gen = x_l + x_h

print(token_ids_l.shape, token_ids_h.shape)
print(x_l.shape, x_h.shape, x_gen.shape)
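The mode="cosine" argument refers to MaskGIT-style iterative decoding, where the fraction of tokens still masked after step t of T follows cos(pi/2 * t/T), so most tokens are committed in the later steps. A minimal sketch of that schedule (an illustration of the standard MaskGIT cosine schedule under that assumption, not necessarily this library's exact implementation):

```python
import math

def cosine_mask_schedule(num_tokens, num_steps):
    """Number of tokens still masked after each decoding step, using the
    MaskGIT cosine schedule: mask_ratio(t) = cos(pi/2 * t / T)."""
    masked = []
    for t in range(1, num_steps + 1):
        ratio = math.cos(math.pi / 2 * t / num_steps)
        masked.append(int(math.floor(ratio * num_tokens)))
    return masked

# E.g. 32 tokens decoded over 10 steps (cf. lf_num_sampling_steps=10 above):
print(cosine_mask_schedule(32, 10))
# [31, 30, 28, 25, 22, 18, 14, 9, 5, 0]
```

Few tokens are revealed early, when the model is least certain, and the schedule always reaches zero masked tokens at the final step.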

Google Colab

Google Colab (note: make sure to change the notebook runtime setting to GPU).

A Google Colab notebook is available for time series generation with the pretrained VQVAE. The usage is simple:

  1. User Settings: specify dataset_name and n_samples_to_generate.
  2. Sampling: run unconditional sampling and class-conditional sampling.

Related Papers

Neural Mapper for Vector Quantized Time Series Generator (NM-VQTSG)

If you want to improve the realism of generated time series while preserving context, please see our Neural Mapper paper.

TimeVQVAE for Anomaly Detection (TimeVQVAE-AD)

If your focus is anomaly detection with explainability and counterfactual sampling, please see TimeVQVAE-AD.

Citation

[1] Lee, Daesoo, Sara Malacarne, and Erlend Aune. "Vector Quantized Time Series Generation with a Bidirectional Prior Model." International Conference on Artificial Intelligence and Statistics. PMLR, 2023.

[2] Lee, Daesoo, Sara Malacarne, and Erlend Aune. "Closing the Gap Between Synthetic and Ground Truth Time Series Distributions via Neural Mapping." arXiv preprint arXiv:2501.17553 (2025).

[3] Lee, Daesoo, Sara Malacarne, and Erlend Aune. "Explainable Time Series Anomaly Detection Using Masked Latent Generative Modeling." Pattern Recognition (2024): 110826.
