PyTorch Implementation of an N-Dimensional VQ-VAE by AdityaNG

nd_vq_vae

Architecture

Install it from PyPI

pip install nd_vq_vae

Cite

Cite our work if you find it useful!

@article{NG2024D3Nav,
  title={D³Nav: Data-Driven Driving Agents for Autonomous Vehicles in Unstructured Traffic},
  author={Aditya NG and Gowri Srinivas},
  journal={The 35th British Machine Vision Conference (BMVC)},
  year={2024},
  url={https://bmvc2024.org/}
}

Usage

Below is an example of encoding temporal video data. Video data is 3D since it spans height and width as well as time. Note that the channels are each treated separately and do not count as a dimension.

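import torch
from nd_vq_vae import NDimVQVAE

batch_size = 2  # any small batch works for a smoke test
sequence_length = 3
channels = 3
res = (128, 256)
input_shape = (channels, sequence_length, res[0], res[1])

# Assumption: one downsampling factor per data dimension (time, height, width).
# The original example read this value from a command-line argument (args.downsample).
downsample = (1, 4, 4)

model = NDimVQVAE(
    embedding_dim=64,
    n_codes=64,
    n_dims=3,
    downsample=downsample,
    n_hiddens=64,
    n_res_layers=2,
    codebook_beta=0.10,
    input_shape=input_shape,
)

# x has shape (batch, channels, time, height, width)
x = torch.randn(batch_size, *input_shape)
recon_loss, x_recon, vq_output = model(x)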
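The rule that channels do not count as a dimension can be written down as a small check. The sketch below is illustrative only: count_data_dims is a hypothetical helper, not part of nd_vq_vae, and it assumes the channel axis always comes first in input_shape, as in the example above.

```python
def count_data_dims(input_shape):
    """Return the number of data dimensions, excluding the leading channel axis."""
    if len(input_shape) < 2:
        raise ValueError("input_shape needs a channel axis plus at least one data axis")
    return len(input_shape) - 1


video_shape = (3, 3, 128, 256)  # (channels, time, height, width)
image_shape = (3, 128, 256)     # (channels, height, width)

video_dims = count_data_dims(video_shape)  # 3, matches n_dims=3 for video
image_dims = count_data_dims(image_shape)  # 2, the image case
```

The result is the value you would pass as n_dims for data of that shape.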

3D: Train on Videos

Videos are 3-dimensional data with axes (Time, Height, Width). You can construct a video dataset at data/video_dataset/ as follows:

$ tree data/video_dataset/
data/video_dataset/
├── test
│   └── Gu1D3BnIYZg.mkv  # you can add more videos to both folders
└── train
    └── ceEE_oYuzS4.mp4
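Before training, it can help to confirm that the folder matches the layout above. The following is a minimal sketch using only the standard library. list_split_files is a hypothetical helper, not part of nd_vq_vae, and the extension list is an assumption based on the two files shown in the tree; the formats the training scripts actually accept may differ.

```python
from pathlib import Path


def list_split_files(root, extensions=(".mp4", ".mkv")):
    """Map each split name ("train", "test") to its sorted media file names."""
    root = Path(root)
    splits = {}
    for split in ("train", "test"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        # Keep only regular files whose suffix is in the (assumed) extension list.
        splits[split] = sorted(
            p.name
            for p in split_dir.iterdir()
            if p.is_file() and p.suffix.lower() in extensions
        )
    return splits
```

Calling list_split_files("data/video_dataset/") returns a dictionary with one list of file names per split, so an empty list points to a split that still needs data.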

2D: Train on Images

Images are 2-dimensional data with axes (Height, Width). You can construct an image dataset at data/image_dataset/ as follows:

$ tree data/image_dataset/
data/image_dataset/
├── test
│   └── 0000001.png  # you can add more images to both folders
└── train
    └── 0000001.png

Then you can use the image training script:

python scripts/train_image.py --data_path data/image_dataset/

1D

Coming soon!

Hyperparameters: How to tune my VQ-VAE?

The VQ-VAE exposes the following hyperparameters:

  1. Codebook size and embedding dimension (n_codes and embedding_dim)
  2. Model capacity (n_hiddens and n_res_layers)
  3. Downsampling strategy (downsample)
  4. Loss balancing (codebook_beta and recon_loss_factor)
  5. Optimization parameters (learning_rate, beta1, beta2)
  6. Training parameters (batch_size, num_epochs)
  7. Attention mechanism (n_head, attn_dropout)
  8. Codebook update strategy (ema_decay)
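One way to keep these settings together is a single configuration object. The sketch below is an assumption-laden illustration, not the library's API: VQVAEHyperparams is a hypothetical class, the first five defaults are copied from the usage example in this README, and every other default is a placeholder chosen for illustration rather than a value taken from nd_vq_vae.

```python
from dataclasses import dataclass, asdict
from typing import Tuple


@dataclass
class VQVAEHyperparams:
    # Values taken from the usage example in this README.
    embedding_dim: int = 64
    n_codes: int = 64
    n_hiddens: int = 64
    n_res_layers: int = 2
    codebook_beta: float = 0.10
    # Everything below is an illustrative placeholder, not a library default.
    downsample: Tuple[int, ...] = (1, 4, 4)
    recon_loss_factor: float = 1.0
    learning_rate: float = 3e-4
    beta1: float = 0.9
    beta2: float = 0.999
    batch_size: int = 8
    num_epochs: int = 10
    n_head: int = 2
    attn_dropout: float = 0.1
    ema_decay: float = 0.99


baseline = VQVAEHyperparams()
trial = VQVAEHyperparams(codebook_beta=0.25)

# Record exactly which fields differ between two runs.
base_values = asdict(baseline)
changed = {k: v for k, v in asdict(trial).items() if base_values[k] != v}
```

Storing the changed dictionary next to each run's logs makes it easy to see which single setting a run varied.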

Below is how you would tune these parameters based on the recon_loss and commitment_loss curves.

  1. Monitor Key Metrics
  • Track these metrics during training and validation:
    • Reconstruction Loss
    • Commitment Loss
    • Perplexity
  2. Analyze Loss Curves
  • Reconstruction Loss
    • High and not decreasing: Increase model capacity (n_hiddens, n_res_layers) or adjust learning rate.
    • Train decreasing, val stable: Potential overfitting. Reduce capacity or add regularization.
    • Both decreasing, val much higher: Increase batch_size or use data augmentation.
  • Commitment Loss
    • Too high: Decrease codebook_beta.
    • Too low or unstable: Increase codebook_beta.
  3. Balance Losses
  • Adjust codebook_beta and recon_loss_factor to achieve a good balance between reconstruction and commitment losses.
  4. Optimize Codebook Usage
  • Monitor perplexity:
    • Low perplexity: Increase n_codes or decrease embedding_dim.
    • High perplexity: Decrease n_codes or increase embedding_dim.
  5. Fine-tune Learning Dynamics
  • Slow convergence: Increase learning_rate or adjust optimizer parameters.
  • Unstable training: Decrease learning_rate or increase batch_size.
  6. Address Overfitting
  • If validation loss plateaus while training loss decreases:
    • Introduce dropout in encoder/decoder
    • Reduce model capacity
    • Increase batch_size or use data augmentation
  7. Attention Mechanism
  • Adjust n_head and attn_dropout in attention blocks for better long-range dependencies.
  8. Codebook Update Strategy
  • Fine-tune ema_decay for codebook stability and adaptation speed.
  9. Downsampling Strategy
  • Adjust downsample factors based on computational resources and required detail level.
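Perplexity in the guide above measures how evenly the codebook is used. Under the standard definition it is the exponential of the entropy of the code-usage histogram, so it ranges from 1 (a single code is used) to n_codes (all codes are used equally). The sketch below implements that definition in plain Python. codebook_perplexity is a hypothetical helper, and whether nd_vq_vae computes its reported perplexity in exactly this way is an assumption.

```python
import math
from collections import Counter


def codebook_perplexity(code_indices):
    """Perplexity of codebook usage: exp of the entropy of the code histogram."""
    counts = Counter(code_indices)
    total = sum(counts.values())
    if total == 0:
        raise ValueError("code_indices is empty")
    # Shannon entropy (in nats) of the empirical code distribution.
    entropy = -sum((c / total) * math.log(c / total) for c in counts.values())
    return math.exp(entropy)


uniform = codebook_perplexity([0, 1, 2, 3] * 10)  # four codes used equally
collapsed = codebook_perplexity([0] * 40)         # only one code ever used
```

Here uniform comes out at about 4 and collapsed at 1. With a real model you would pass the flattened list of code indices produced for a batch and compare the result against n_codes.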

Best Practices

  • Make incremental changes to hyperparameters.
  • Perform ablation studies, changing one parameter at a time.
  • Consider using learning rate scheduling or cyclical learning rates.
  • Regularly save checkpoints and log experiments for comparison.
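As an illustration of learning rate scheduling, here is a minimal cosine-annealing schedule in plain Python. cosine_lr is a hypothetical helper written for this example, and the step counts and rates are placeholders. In a PyTorch training loop, torch.optim.lr_scheduler.CosineAnnealingLR provides the same kind of schedule.

```python
import math


def cosine_lr(step, total_steps, lr_max, lr_min=0.0):
    """Cosine-annealed learning rate for a given step in [0, total_steps]."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    # Clamp the step so values outside the range hold the boundary rates.
    progress = min(max(step, 0), total_steps) / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


# Placeholder settings: 1000 steps, decaying from 3e-4 to 3e-6.
schedule = [cosine_lr(s, 1000, 3e-4, 3e-6) for s in (0, 500, 1000)]
```

The schedule starts at lr_max, passes through the midpoint of the two rates halfway through training, and ends at lr_min.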

Development

Read the CONTRIBUTING.md file.

