

Project description

NDimVQVAE


PyTorch Implementation of the N-Dimensional VQ-VAE by AdityaNG.

VQ-VAE is a powerful technique for learning discrete representations of high-dimensional data. NDimVQVAE extends this concept to support N-dimensional data, making it ideal for building foundation models that process various types of sensor data.

Everyone is looking to build their own foundation model! Multimodal foundation models require us to condense high-dimensional sensor data into a low-dimensional quantized space. That's where NDimVQVAE comes in!

NDimVQVAE supports encoding 1D, 2D, and 3D data (4D if you count channels)!

If you want to build a foundation model that processes video data, point cloud data or signals from a Tokamak or Stellarator for magnetic confinement, this is your one-stop shop to build an Encoder-Decoder pair!

Architecture

(Architecture diagram omitted here; see the project repository.)

Install it from PyPI

pip install nd_vq_vae

Cite

D³Nav pilot studies used VQ-VAE encoders to feed video data (3D), image data (2D), and trajectory data (1D) into our transformer model. We have made the NDimVQVAE source code public. Cite our work if you find it useful!

@article{NG2024D3Nav,
  title={D³Nav: Data-Driven Driving Agents for Autonomous Vehicles in Unstructured Traffic},
  author={Aditya NG and Gowri Srinivas},
  journal={The 35th British Machine Vision Conference (BMVC)},
  year={2024},
  url={https://bmvc2024.org/}
}

Usage

Below is an example of encoding temporal video data. Video data is 3D since it spans height and width as well as time. Note that the channels are treated separately and do not count as a dimension.

import torch

from nd_vq_vae import NDimVQVAE

batch_size = 2  # example value
sequence_length = 3
channels = 3
res = (128, 256)
input_shape = (channels, sequence_length, res[0], res[1])

model = NDimVQVAE(
    embedding_dim=64,
    n_codes=64,
    n_dims=3,
    downsample=(4, 8, 8),  # example per-axis factors; tune for your data
    n_hiddens=64,
    n_res_layers=2,
    codebook_beta=0.10,
    input_shape=input_shape,
)

x = torch.randn(batch_size, *input_shape)
recon_loss, x_recon, vq_output = model(x)
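
The forward pass returns the reconstruction loss, the reconstruction itself, and the quantizer output. As a quick sanity check (a sketch; the exact contents of vq_output depend on the implementation):

assert x_recon.shape == x.shape  # reconstruction has the same shape as the input
print(f"recon_loss: {recon_loss.item():.4f}")  # scalar reconstruction loss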

3D: Train on Videos

Videos are 3-dimensional data with (Time, Height, Width). You can construct a video dataset at data/video_dataset/ as follows:

$ tree data/video_dataset/
data/video_dataset/
├── test
│   └── Gu1D3BnIYZg.mkv  # you can add more videos to both folders
└── train
    └── ceEE_oYuzS4.mp4
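
To get a clip from such a video into the (channels, time, height, width) layout the model expects, here is a minimal sketch using torchvision (an assumption on our part; the project's own dataloader may differ):

import torch
from torchvision.io import read_video

# read_video returns frames as a (T, H, W, C) uint8 tensor.
frames, _, _ = read_video("data/video_dataset/train/ceEE_oYuzS4.mp4", pts_unit="sec")
clip = frames[:3]                              # sequence_length = 3 frames
x = clip.permute(3, 0, 1, 2).float() / 255.0   # -> (C, T, H, W), scaled to [0, 1]
x = x.unsqueeze(0)                             # add a batch dimension
# Resize or crop to the model's (height, width) before feeding it in.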

2D: Train on Images

Images are 2-dimensional data with (Height, Width). You can construct an image dataset at data/image_dataset/ as follows:

$ tree data/image_dataset/
data/image_dataset/
├── test
│   └── 0000001.png  # you can add more images to both folders
└── train
    └── 0000001.png

Then you can use the image training script:

python scripts/train_image.py --data_path data/image_dataset/
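
A 2D model is constructed by analogy with the 3D example above, with n_dims=2 and an input_shape without the time axis (a sketch under that assumption; the downsample factors are placeholders):

from nd_vq_vae import NDimVQVAE
import torch

channels = 3
res = (128, 128)
input_shape = (channels, res[0], res[1])

model = NDimVQVAE(
    embedding_dim=64,
    n_codes=64,
    n_dims=2,
    downsample=(4, 4),  # example (height, width) factors
    n_hiddens=64,
    n_res_layers=2,
    codebook_beta=0.10,
    input_shape=input_shape,
)

x = torch.randn(8, *input_shape)
recon_loss, x_recon, vq_output = model(x)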

1D

Coming soon!

Hyperparameters: How to tune my VQ-VAE?

The VQ-VAE has the following hyperparameters:

  1. Codebook size and embedding dimension (n_codes and embedding_dim)
  2. Model capacity (n_hiddens and n_res_layers)
  3. Downsampling strategy (downsample)
  4. Loss balancing (codebook_beta and recon_loss_factor)
  5. Optimization parameters (learning_rate, beta1, beta2)
  6. Training parameters (batch_size, num_epochs)
  7. Attention mechanism (n_head, attn_dropout)
  8. Codebook update strategy (ema_decay)

Below is how you would tune these parameters based on the recon_loss and commitment_loss curves.

  1. Monitor Key Metrics
  • Track these metrics during training and validation:
    • Reconstruction Loss
    • Commitment Loss
    • Perplexity
    • Note that the reconstruction loss is in the same units as the input data: if the input data is in meters, the reconstruction loss will also be in meters!
  2. Analyze Loss Curves
  • Reconstruction Loss
    • High and not decreasing: Increase model capacity (n_hiddens, n_res_layers) or adjust learning rate.
    • Train decreasing, val stable: Potential overfitting. Reduce capacity or add regularization.
    • Both decreasing, val much higher: Increase batch_size or use data augmentation.
  • Commitment Loss
    • Too high: Decrease codebook_beta.
    • Too low or unstable: Increase codebook_beta.
  3. Balance Losses
  • Adjust codebook_beta and recon_loss_factor to achieve a good balance between reconstruction and commitment losses.
  4. Optimize Codebook Usage
  • Monitor perplexity (see the sketch after this list):
    • Low perplexity: Increase n_codes or decrease embedding_dim.
    • High perplexity: Decrease n_codes or increase embedding_dim.
  5. Fine-tune Learning Dynamics
  • Slow convergence: Increase learning_rate or adjust optimizer parameters.
  • Unstable training: Decrease learning_rate or increase batch_size.
  6. Address Overfitting
  • If validation loss plateaus while training loss decreases:
    • Introduce dropout in encoder/decoder
    • Reduce model capacity
    • Increase batch_size or use data augmentation
  7. Attention Mechanism
  • Adjust n_head and attn_dropout in attention blocks for better long-range dependencies.
  8. Codebook Update Strategy
  • Fine-tune ema_decay for codebook stability and adaptation speed.
  9. Downsampling Strategy
  • Adjust downsample factors based on computational resources and required detail level.
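
As a rough illustration of the perplexity metric referenced above, the sketch below computes codebook perplexity from a batch of code indices. It assumes you can extract the per-element code indices from vq_output (the exact key name depends on the implementation), and the helper name is ours:

import torch

def codebook_perplexity(indices, n_codes):
    # Histogram of code usage across the batch -> probability distribution.
    counts = torch.bincount(indices.flatten(), minlength=n_codes).float()
    probs = counts / counts.sum()
    # Perplexity = exp(entropy): n_codes means perfectly uniform codebook usage.
    entropy = -(probs * torch.log(probs + 1e-10)).sum()
    return torch.exp(entropy)

A perplexity near 1 means the model has collapsed onto a few codes; a value approaching n_codes means the whole codebook is in use.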

Best Practices

  • Make incremental changes to hyperparameters.
  • Perform ablation studies, changing one parameter at a time.
  • Consider using learning rate scheduling or cyclical learning rates (a minimal sketch follows this list).
  • Regularly save checkpoints and log experiments for comparison.
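
For example, a cyclical learning rate can be set up with PyTorch's built-in scheduler. This is a generic sketch, not part of the nd_vq_vae API, and the learning-rate bounds are placeholders:

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=1e-5, max_lr=3e-4,
    cycle_momentum=False,  # required for optimizers without momentum, such as Adam
)

# Call scheduler.step() once per training batch.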

Development

Read the CONTRIBUTING.md file.



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

nd_vq_vae-1.0.1.tar.gz (20.9 kB)


Built Distribution

nd_vq_vae-1.0.1-py3-none-any.whl (15.2 kB)


File details

Details for the file nd_vq_vae-1.0.1.tar.gz.

File metadata

  • Download URL: nd_vq_vae-1.0.1.tar.gz
  • Size: 20.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for nd_vq_vae-1.0.1.tar.gz

Algorithm    Hash digest
SHA256       ec6195b9dab154f2ff4077f37fae3d72b59491e953018692a4ee4464ca4aa79d
MD5          a28d7f0fca8092aadeb3768cc7174987
BLAKE2b-256  4393bf733862fffa5f7fdce08cb8fb4b192c36076ed5db725e4e64829d1b24a5



File details

Details for the file nd_vq_vae-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: nd_vq_vae-1.0.1-py3-none-any.whl
  • Size: 15.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for nd_vq_vae-1.0.1-py3-none-any.whl

Algorithm    Hash digest
SHA256       6ea2fb8b8658326964518779952fd562b28d0c6efcb3bd4497e3ab96bd7d96f6
MD5          904f0703475f03b6f474ad656292301f
BLAKE2b-256  6b8b2e3a29116573b5aad452a4c6421592e4e22bd6702165537f0e49120804d7


