PyTorch Implementation of an N-Dimensional VQ-VAE by AdityaNG
NDimVQVAE
PyTorch Implementation of the N-Dimensional VQ-VAE by AdityaNG.
VQ-VAE is a powerful technique for learning discrete representations of high-dimensional data. NDimVQVAE extends this concept to support N-dimensional data, making it ideal for building foundation models that process various types of sensor data.
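The step that makes the representation discrete is a nearest-neighbour lookup into a codebook of `n_codes` vectors, each of size `embedding_dim`. Below is a minimal NumPy sketch of that standard VQ-VAE lookup, not NDimVQVAE's own implementation; the function name `quantize` is hypothetical.

```python
import numpy as np

def quantize(z_e, codebook):
    """Map each encoder vector to its nearest codebook entry.

    z_e:      (N, embedding_dim) encoder outputs, flattened over all positions
    codebook: (n_codes, embedding_dim) embedding table
    Returns integer code indices (N,) and quantized vectors (N, embedding_dim).
    """
    # Squared Euclidean distance for every (vector, code) pair:
    # ||z||^2 - 2 z.e + ||e||^2  ->  shape (N, n_codes)
    dists = (
        (z_e ** 2).sum(axis=1, keepdims=True)
        - 2.0 * z_e @ codebook.T
        + (codebook ** 2).sum(axis=1)
    )
    indices = dists.argmin(axis=1)  # one discrete token per position
    z_q = codebook[indices]         # replace each vector by its nearest code
    return indices, z_q

rng = np.random.default_rng(0)
codebook = rng.normal(size=(64, 8))  # n_codes=64, embedding_dim=8
z_e = codebook[[3, 17, 42]] + 1e-3   # vectors lying very close to codes 3, 17, 42
indices, z_q = quantize(z_e, codebook)
print(indices.tolist())  # [3, 17, 42]
```

During training, gradients are usually passed through the non-differentiable `argmin` with a straight-through estimator; NumPy has no autograd, so that part is left out of this sketch.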
Everyone is looking to build their own foundation model! Multimodal foundation models require us to condense high-dimensional sensor data into a low-dimensional quantized space. That's where NDimVQVAE comes in!
NDimVQVAE can encode 1D, 2D and 3D data (4D if you count channels)!
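Following the convention of the Usage example below (channels first, then one size per data dimension), `n_dims` is simply the number of sizes after the channel axis. The helper `make_input_shape` is hypothetical, and the 2-channel trajectory is an illustrative assumption.

```python
def make_input_shape(channels, *sizes):
    """Return (input_shape, n_dims); channels lead and are not counted in n_dims."""
    if not sizes:
        raise ValueError("need at least one spatial or temporal size")
    return (channels, *sizes), len(sizes)

# 1D: a trajectory with 2 channels (x, y) over 16 timesteps
print(make_input_shape(2, 16))           # ((2, 16), 1)
# 2D: an RGB image of 128 x 256
print(make_input_shape(3, 128, 256))     # ((3, 128, 256), 2)
# 3D: 3 RGB frames of 128 x 256, as in the Usage example
print(make_input_shape(3, 3, 128, 256))  # ((3, 3, 128, 256), 3)
```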
If you want to build a foundation model that processes video data, point cloud data or signals from a Tokamak or Stellarator for magnetic confinement, this is your one-stop shop to build an Encoder-Decoder pair!
Install it from PyPI
pip install nd_vq_vae
Cite
D³Nav pilot studies used VQ-VAE encoders to feed video data (3D), image data (2D) and trajectory data (1D) into our transformer model. We have made the NDimVQVAE source code public.
Cite our work if you find it useful!
@article{NG2024D3Nav,
title={D³Nav: Data-Driven Driving Agents for Autonomous Vehicles in Unstructured Traffic},
author={Aditya NG and Gowri Srinivas},
journal={The 35th British Machine Vision Conference (BMVC)},
year={2024},
url={https://bmvc2024.org/}
}
Usage
Below is an example of encoding temporal video data. Video data is 3D because it spans time as well as height and width. Channels are handled separately and do not count as a dimension.
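import torch
from nd_vq_vae import NDimVQVAE

batch_size = 2
sequence_length = 3  # number of frames (time)
channels = 3         # RGB; channels are not counted in n_dims
res = (128, 256)     # (height, width)
input_shape = (channels, sequence_length, res[0], res[1])

# Assumption: one downsampling factor per data dimension (time, height, width).
# The original example read this value from a command-line argument (args.downsample).
downsample = (1, 4, 4)

model = NDimVQVAE(
    embedding_dim=64,
    n_codes=64,
    n_dims=3,
    downsample=downsample,
    n_hiddens=64,
    n_res_layers=2,
    codebook_beta=0.10,
    input_shape=input_shape,
)
x = torch.randn(batch_size, *input_shape)
recon_loss, x_recon, vq_output = model(x)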
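For a downstream transformer, what matters is how many discrete tokens one sample becomes. The sketch below assumes `downsample` holds one integer factor per data dimension that divides each size exactly; that convention is an assumption, so check the library source. The helper `latent_grid` is hypothetical, and `(1, 4, 4)` is an illustrative choice for the 3-frame 128×256 clip in the Usage example.

```python
import math

def latent_grid(input_shape, downsample):
    """Shape of the grid of code indices for one sample.

    Assumes input_shape = (channels, *sizes) and one factor per entry of sizes.
    """
    sizes = input_shape[1:]
    if len(sizes) != len(downsample):
        raise ValueError("need one downsample factor per data dimension")
    for size, factor in zip(sizes, downsample):
        if size % factor:
            raise ValueError(f"{size} is not divisible by {factor}")
    return tuple(size // factor for size, factor in zip(sizes, downsample))

grid = latent_grid((3, 3, 128, 256), (1, 4, 4))
print(grid, math.prod(grid))  # (3, 32, 64) 6144
```

With `n_codes=64`, each of those 6144 positions would hold one integer in the range [0, 64).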
3D: Train on Videos
Videos are 3-dimensional data (Time, Height, Width).
You can construct a video dataset at data/video_dataset/ as follows:
$ tree data/video_dataset/
data/video_dataset/
├── test
│   └── Gu1D3BnIYZg.mkv   # you can add more videos to both folders
└── train
    └── ceEE_oYuzS4.mp4
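If your videos start out in a single folder, a small script can copy them into this train/test layout. This is a hypothetical helper (`make_splits`), not part of the package. It only picks up `.mp4` and `.mkv`, the two extensions shown in the tree; whether the video loader accepts other formats is not stated here.

```python
import random
import shutil
from pathlib import Path

VIDEO_EXTS = {".mp4", ".mkv"}  # extensions taken from the example tree

def make_splits(src_dir, dst_root, test_fraction=0.2, seed=0):
    """Copy videos from src_dir into dst_root/train and dst_root/test.

    Returns the number of files placed in each split.
    """
    files = sorted(p for p in Path(src_dir).iterdir() if p.suffix.lower() in VIDEO_EXTS)
    random.Random(seed).shuffle(files)  # reproducible split
    n_test = max(1, round(len(files) * test_fraction)) if files else 0
    splits = {"test": files[:n_test], "train": files[n_test:]}
    for split, members in splits.items():
        out = Path(dst_root) / split
        out.mkdir(parents=True, exist_ok=True)
        for f in members:
            shutil.copy2(f, out / f.name)
    return {split: len(members) for split, members in splits.items()}

# Example (hypothetical source folder):
# make_splits("raw_videos/", "data/video_dataset/")
```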
2D: Train on Images
Images are 2-dimensional data (Height, Width).
You can construct an image dataset at data/image_dataset/ as follows:
$ tree data/image_dataset/
data/image_dataset/
├── test
│   └── 0000001.png   # you can add more images to both folders
└── train
    └── 0000001.png
Then you can use the image training script:
python scripts/train_image.py --data_path data/image_dataset/
1D
Coming soon!
Hyperparameters: How to tune my VQ-VAE?
The VQ-VAE has the following hyperparameters:
- Codebook size and embedding dimension (n_codes and embedding_dim)
- Model capacity (n_hiddens and n_res_layers)
- Downsampling strategy (downsample)
- Loss balancing (codebook_beta and recon_loss_factor)
- Optimization parameters (learning_rate, beta1, beta2)
- Training parameters (batch_size, num_epochs)
- Attention mechanism (n_head, attn_dropout)
- Codebook update strategy (ema_decay)
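To see what `ema_decay` controls, here is the exponential-moving-average codebook update used in many VQ-VAE implementations, written in NumPy. Whether NDimVQVAE uses exactly this form, including the Laplace smoothing of the counts, is an assumption. The function name and the `cluster_size`/`embed_sum` buffers are hypothetical.

```python
import numpy as np

def ema_codebook_update(z_e, indices, cluster_size, embed_sum, decay=0.99, eps=1e-5):
    """One EMA codebook step; updates the float running statistics in place.

    z_e:          (N, D) encoder outputs for the current batch
    indices:      (N,)   code index assigned to each vector
    cluster_size: (K,)   running (EMA) count of vectors per code
    embed_sum:    (K, D) running (EMA) sum of vectors per code
    Returns the new (K, D) codebook.
    """
    K = cluster_size.shape[0]
    one_hot = np.eye(K)[indices]        # (N, K) assignment matrix
    batch_counts = one_hot.sum(axis=0)  # vectors assigned to each code
    batch_sums = one_hot.T @ z_e        # sum of vectors assigned to each code
    cluster_size *= decay
    cluster_size += (1 - decay) * batch_counts
    embed_sum *= decay
    embed_sum += (1 - decay) * batch_sums
    # Laplace smoothing: unused codes never divide by zero
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + K * eps) * n
    return embed_sum / smoothed[:, None]
```

A `decay` close to 1 keeps the codebook stable but slow to adapt. A smaller value lets it follow the current batch more closely. With `decay=0`, each used code jumps to the mean of the vectors assigned to it in the batch.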
Below is how you would tune these parameters based on the recon_loss and commitment_loss curves.
- Monitor Key Metrics
  - Track these metrics during training and validation:
    - Reconstruction Loss
    - Commitment Loss
    - Perplexity
  - The reconstruction loss is tied to the units of the input data. For example, if the input data is in meters, an L1 (mean absolute error) reconstruction loss is also in meters, while a mean squared error is in square meters.
- Analyze Loss Curves
  - Reconstruction Loss
    - High and not decreasing: Increase model capacity (n_hiddens, n_res_layers) or adjust the learning rate.
    - Train decreasing, val stable: Potential overfitting. Reduce capacity or add regularization.
    - Both decreasing, val much higher: Increase batch_size or use data augmentation.
  - Commitment Loss
    - Too high: Decrease codebook_beta.
    - Too low or unstable: Increase codebook_beta.
- Balance Losses
  - Adjust codebook_beta and recon_loss_factor to achieve a good balance between reconstruction and commitment losses.
- Optimize Codebook Usage
  - Monitor perplexity:
    - Low perplexity: Increase n_codes or decrease embedding_dim.
    - High perplexity: Decrease n_codes or increase embedding_dim.
- Fine-tune Learning Dynamics
  - Slow convergence: Increase learning_rate or adjust optimizer parameters.
  - Unstable training: Decrease learning_rate or increase batch_size.
- Address Overfitting
  - If validation loss plateaus while training loss decreases:
    - Introduce dropout in the encoder/decoder
    - Reduce model capacity
    - Increase batch_size or use data augmentation
- Attention Mechanism
  - Adjust n_head and attn_dropout in the attention blocks to better capture long-range dependencies.
- Codebook Update Strategy
  - Tune ema_decay to trade codebook stability (higher values) against adaptation speed (lower values).
- Downsampling Strategy
  - Adjust the downsample factors based on your computational resources and the level of detail you need.
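Perplexity summarizes how many codes are effectively in use. A common definition, assumed here, is the exponential of the entropy of the code-usage histogram over a batch. It ranges from 1 (every position uses the same code) to `n_codes` (all codes used equally). The function name is hypothetical, and NDimVQVAE may compute it slightly differently.

```python
import numpy as np

def codebook_perplexity(indices, n_codes):
    """exp(entropy) of the code-usage distribution in a batch of code indices."""
    counts = np.bincount(np.ravel(indices), minlength=n_codes)
    if counts.sum() == 0:
        raise ValueError("no code indices given")
    probs = counts / counts.sum()
    nonzero = probs[probs > 0]  # treat 0 * log(0) as 0
    return float(np.exp(-(nonzero * np.log(nonzero)).sum()))

print(codebook_perplexity([5, 5, 5, 5], n_codes=64))     # 1.0 (collapsed codebook)
print(codebook_perplexity(list(range(64)), n_codes=64))  # close to 64.0 (uniform usage)
```

A value far below `n_codes` means most positions map to a handful of codes, which is the "Low perplexity" case in the list above.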
Best Practices
- Make incremental changes to hyperparameters.
- Perform ablation studies, changing one parameter at a time.
- Consider using learning rate scheduling or cyclical learning rates.
- Regularly save checkpoints and log experiments for comparison.
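Ablations that change one parameter at a time are easy to generate from a baseline config. The helper `one_at_a_time` is hypothetical. The baseline values come from the Usage example, and the sweep values are only illustrative.

```python
def one_at_a_time(base, sweeps):
    """Yield (name, value, config) triples that change exactly one
    hyperparameter of `base`, leaving all others at their baseline values."""
    for name, values in sweeps.items():
        if name not in base:
            raise KeyError(f"unknown hyperparameter: {name}")
        for value in values:
            if value == base[name]:
                continue  # the baseline run is already covered
            config = dict(base)
            config[name] = value
            yield name, value, config

base = {"n_codes": 64, "embedding_dim": 64, "codebook_beta": 0.10}
sweeps = {"n_codes": [32, 64, 128], "codebook_beta": [0.05, 0.25]}
for name, value, config in one_at_a_time(base, sweeps):
    print(name, value, config)
```

Logging `name` and `value` alongside each checkpoint keeps the runs easy to compare later.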
Architecture
Development
Read the CONTRIBUTING.md file.
File details
Details for the file nd_vq_vae-1.0.3.tar.gz.
File metadata
- Download URL: nd_vq_vae-1.0.3.tar.gz
- Upload date:
- Size: 20.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 7a5b61c8748ada38ec913df36ea7351fa6046734bf42815b2a5327f5b14d935e
MD5 | 99185de2660d114061bba889de0c95fd
BLAKE2b-256 | 3f824ec94531aab01408aa82083683ac5a89e1b202819038c3797b261ab7e955
File details
Details for the file nd_vq_vae-1.0.3-py3-none-any.whl.
File metadata
- Download URL: nd_vq_vae-1.0.3-py3-none-any.whl
- Upload date:
- Size: 15.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 29ac9c7ed974edc96af8d8b884ae9ddfba8efdae9fa48e2571439c4b4f739bda
MD5 | 58f8a4aa28c2cd79d3c66d1be9da7dd4
BLAKE2b-256 | f1556ebb06f0ce368c69feda742430b9c08c2a3b06c0ad59284315820e86db94