nd_vq_vae
PyTorch Implementation of the N-Dimensional VQ-VAE by AdityaNG
Install it from PyPI:
pip install nd_vq_vae
Cite
Cite our work if you find it useful!
@article{NG2024D3Nav,
  title={D³Nav: Data-Driven Driving Agents for Autonomous Vehicles in Unstructured Traffic},
  author={Aditya NG and Gowri Srinivas},
  journal={The 35th British Machine Vision Conference (BMVC)},
  year={2024},
  url={https://bmvc2024.org/}
}
Usage
Below is an example of encoding temporal video data. Video data is 3-dimensional since it spans time as well as height and width. Note that the channel axis is handled separately and does not count as one of the N dimensions.
import torch

from nd_vq_vae import NDimVQVAE

batch_size = 2
sequence_length = 3
channels = 3
res = (128, 256)
input_shape = (channels, sequence_length, res[0], res[1])

model = NDimVQVAE(
    embedding_dim=64,
    n_codes=64,
    n_dims=3,
    downsample=(2, 4, 4),  # illustrative per-dimension factors for (time, height, width)
    n_hiddens=64,
    n_res_layers=2,
    codebook_beta=0.10,
    input_shape=input_shape,
)

x = torch.randn(batch_size, *input_shape)
recon_loss, x_recon, vq_output = model(x)
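Here recon_loss is a scalar reconstruction loss, x_recon is the reconstruction of x, and vq_output carries the quantizer outputs (its exact keys depend on the library version). A quick sanity check:

assert x_recon.shape == x.shape
print(float(recon_loss))  # scalar reconstruction loss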
3D: Train on Videos
Videos are 3-dimensional data with axes (Time, Height, Width).
You can construct a video dataset at data/video_dataset/
as follows:
$ tree data/video_dataset/
data/video_dataset/
├── test
│   └── Gu1D3BnIYZg.mkv  # you can add more videos to both folders
└── train
    └── ceEE_oYuzS4.mp4
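The repository handles data loading for you; for orientation, here is a minimal sketch of a Dataset that turns such a folder into (channels, time, height, width) clips. VideoClipDataset and all values here are illustrative (not the library's own loader), it assumes torchvision can decode your codecs, and it does no resizing to the model's resolution:

import torch
from pathlib import Path
from torch.utils.data import Dataset
from torchvision.io import read_video

class VideoClipDataset(Dataset):
    """Yields fixed-length clips shaped (C, T, H, W) from a folder of videos."""

    def __init__(self, root: str, sequence_length: int = 3):
        self.paths = sorted(Path(root).glob("*"))
        self.sequence_length = sequence_length

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # read_video returns frames as (T, H, W, C) uint8
        frames, _, _ = read_video(str(self.paths[idx]), pts_unit="sec")
        clip = frames[: self.sequence_length]            # keep the first T frames
        return clip.permute(3, 0, 1, 2).float() / 255.0  # (C, T, H, W) in [0, 1]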
2D: Train on Images
Images are 2-dimensional data with axes (Height, Width).
You can construct an image dataset at data/image_dataset/
as follows:
$ tree data/image_dataset/
data/image_dataset/
├── test
│   └── 0000001.png  # you can add more images to both folders
└── train
    └── 0000001.png
Then you can use the image training script:
python scripts/train_image.py --data_path data/image_dataset/
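For reference, the same model handles images by setting n_dims=2 and a 3-tuple input_shape; a minimal sketch (the downsample factors are illustrative):

import torch
from nd_vq_vae import NDimVQVAE

input_shape = (3, 128, 256)  # (channels, height, width); channels do not count as a dimension
model = NDimVQVAE(
    embedding_dim=64,
    n_codes=64,
    n_dims=2,
    downsample=(4, 4),  # illustrative per-dimension factors for (height, width)
    n_hiddens=64,
    n_res_layers=2,
    codebook_beta=0.10,
    input_shape=input_shape,
)

x = torch.randn(8, *input_shape)
recon_loss, x_recon, vq_output = model(x)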
1D
Coming soon!
Hyperparameters: How to tune my VQ-VAE?
The VQ-VAE has the following hyperparameters:
- Codebook size and embedding dimension (n_codes and embedding_dim)
- Model capacity (n_hiddens and n_res_layers)
- Downsampling strategy (downsample)
- Loss balancing (codebook_beta and recon_loss_factor)
- Optimization parameters (learning_rate, beta1, beta2)
- Training parameters (batch_size, num_epochs)
- Attention mechanism (n_head, attn_dropout)
- Codebook update strategy (ema_decay)
Below is how you would tune these parameters based on the recon_loss and commitment_loss curves.
- Monitor Key Metrics
  - Track these metrics during training and validation:
    - Reconstruction Loss
    - Commitment Loss
    - Perplexity
- Analyze Loss Curves
  - Reconstruction Loss
    - High and not decreasing: Increase model capacity (n_hiddens, n_res_layers) or adjust the learning rate.
    - Train decreasing, val stable: Potential overfitting. Reduce capacity or add regularization.
    - Both decreasing, val much higher: Increase batch_size or use data augmentation.
  - Commitment Loss
    - Too high: Decrease codebook_beta.
    - Too low or unstable: Increase codebook_beta.
- Balance Losses
  - Adjust codebook_beta and recon_loss_factor to achieve a good balance between reconstruction and commitment losses (a loss sketch follows this list).
- Optimize Codebook Usage
  - Monitor perplexity (a perplexity sketch also follows this list):
    - Low perplexity: Increase n_codes or decrease embedding_dim.
    - High perplexity: Decrease n_codes or increase embedding_dim.
- Fine-tune Learning Dynamics
  - Slow convergence: Increase learning_rate or adjust optimizer parameters.
  - Unstable training: Decrease learning_rate or increase batch_size.
- Address Overfitting
  - If validation loss plateaus while training loss decreases:
    - Introduce dropout in the encoder/decoder.
    - Reduce model capacity.
    - Increase batch_size or use data augmentation.
- Attention Mechanism
  - Adjust n_head and attn_dropout in the attention blocks for better long-range dependencies.
- Codebook Update Strategy
  - Fine-tune ema_decay for codebook stability and adaptation speed.
- Downsampling Strategy
  - Adjust downsample factors based on computational resources and the required level of detail.
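For the loss-balancing step above, here is a sketch of the standard VQ-VAE objective these knobs enter into. The tensors are stand-ins, and when EMA codebook updates are active (via ema_decay) the codebook term is typically replaced by the EMA update rather than a gradient step:

import torch
import torch.nn.functional as F

codebook_beta, recon_loss_factor = 0.10, 1.0
z_e = torch.randn(2, 64, 8, 16, requires_grad=True)  # stand-in encoder output
z_q = torch.randn(2, 64, 8, 16, requires_grad=True)  # stand-in quantized codebook vectors
recon_loss = torch.tensor(0.5)                       # stand-in reconstruction loss

codebook_loss = F.mse_loss(z_q, z_e.detach())                    # pulls codes toward encoder outputs
commitment_loss = codebook_beta * F.mse_loss(z_e, z_q.detach())  # keeps the encoder close to its codes
loss = recon_loss_factor * recon_loss + codebook_loss + commitment_loss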
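And a sketch of the perplexity metric itself, computed from the code indices the quantizer selects (the key under which the library exposes these indices may differ):

import torch

def codebook_perplexity(code_indices: torch.Tensor, n_codes: int) -> torch.Tensor:
    """exp(entropy) of empirical code usage; equals n_codes when all codes are used equally."""
    counts = torch.bincount(code_indices.flatten(), minlength=n_codes).float()
    probs = counts / counts.sum()
    entropy = -(probs * torch.log(probs + 1e-10)).sum()
    return entropy.exp()

print(codebook_perplexity(torch.randint(0, 64, (2, 3, 32, 64)), n_codes=64))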
Best Practices
- Make incremental changes to hyperparameters.
- Perform ablation studies, changing one parameter at a time.
- Consider using learning rate scheduling or cyclical learning rates (a sketch follows this list).
- Regularly save checkpoints and log experiments for comparison.
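For the scheduling suggestion above, a minimal cyclical-learning-rate loop. The optimizer settings and the bare recon_loss backward pass are illustrative; the repository's training scripts may combine additional loss terms:

import torch
from torch.optim.lr_scheduler import CyclicLR
from nd_vq_vae import NDimVQVAE

input_shape = (3, 3, 128, 256)
model = NDimVQVAE(
    embedding_dim=64, n_codes=64, n_dims=3,
    downsample=(2, 4, 4), n_hiddens=64, n_res_layers=2,
    codebook_beta=0.10, input_shape=input_shape,
)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.999))
scheduler = CyclicLR(optimizer, base_lr=1e-5, max_lr=3e-4,
                     step_size_up=2000, cycle_momentum=False)  # Adam has no momentum to cycle

for step in range(10):                   # stand-in for a real DataLoader loop
    x = torch.randn(2, *input_shape)
    recon_loss, x_recon, vq_output = model(x)
    optimizer.zero_grad()
    recon_loss.backward()
    optimizer.step()
    scheduler.step()                     # advance the cyclical schedule once per batch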
Development
Read the CONTRIBUTING.md file.