NDimVQVAE

PyTorch Implementation of the N-Dimensional VQ-VAE by AdityaNG.

Architecture

VQ-VAE is a powerful technique for learning discrete representations of high-dimensional data. NDimVQVAE extends this concept to support N-dimensional data, making it ideal for building foundation models that process various types of sensor data.
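
At the heart of every VQ-VAE is the vector-quantization bottleneck: encoder outputs are snapped to their nearest codebook vectors, and gradients flow through the discrete lookup via the straight-through estimator. Below is a minimal, library-agnostic sketch of that step (illustrative only, not NDimVQVAE's internals):

import torch
import torch.nn.functional as F

def vector_quantize(z_e, codebook):
    # Minimal VQ step: nearest-neighbour lookup + straight-through gradient.
    #   z_e:      (batch, n, embedding_dim) flattened encoder outputs
    #   codebook: (n_codes, embedding_dim) learnable code vectors
    # Distance from every encoder vector to every codebook entry
    dists = torch.cdist(z_e, codebook.unsqueeze(0).expand(z_e.size(0), -1, -1))
    codes = dists.argmin(dim=-1)            # (batch, n) discrete indices
    z_q = F.embedding(codes, codebook)      # quantized vectors
    # Commitment loss pulls the encoder toward its assigned codes
    commitment_loss = F.mse_loss(z_e, z_q.detach())
    # Straight-through estimator: gradients bypass the argmin
    z_q = z_e + (z_q - z_e).detach()
    return z_q, codes, commitment_loss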

Everyone is looking to build their own foundation model! Multimodal foundation models require us to condense high-dimensional sensor data into a low-dimensional quantized space. That's where NDimVQVAE comes in!

NDimVQVAE supports encoding 1D, 2D, and 3D data (4D if you count the channel axis)!

If you want to build a foundation model that processes video data, point cloud data or signals from a Tokamak or Stellarator for magnetic confinement, this is your one-stop shop to build an Encoder-Decoder pair!
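
Switching modality is, in principle, just a change of n_dims and input_shape. A hedged sketch, assuming the same constructor arguments as the Usage example below (the shapes and downsample factors are illustrative):

from nd_vq_vae import NDimVQVAE

def make_model(n_dims, input_shape, downsample):
    # Shared settings follow the Usage example below; only the
    # dimensionality-specific arguments change per modality.
    return NDimVQVAE(
        embedding_dim=64, n_codes=64, n_hiddens=64, n_res_layers=2,
        codebook_beta=0.10, n_dims=n_dims,
        downsample=downsample, input_shape=input_shape,
    )

signal_model = make_model(1, (1, 16000), (4,))             # 1D: (channels, time)
image_model = make_model(2, (3, 128, 256), (4, 4))         # 2D: (channels, H, W)
video_model = make_model(3, (3, 16, 128, 256), (4, 4, 4))  # 3D: (channels, T, H, W)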

Install it from PyPI

pip install nd_vq_vae

Cite

D³Nav pilot studies used VQ-VAE encoders to feed video data (3D), image data (2D), and trajectory data (1D) into our transformer model. We are making our NDimVQVAE source code public; cite our work if you find it useful!

@article{NG2024D3Nav,
  title={D³Nav: Data-Driven Driving Agents for Autonomous Vehicles in Unstructured Traffic},
  author={Aditya NG and Gowri Srinivas},
  journal={The 35th British Machine Vision Conference (BMVC)},
  year={2024},
  url={https://bmvc2024.org/}
}

Usage

Below is an example of encoding temporal video data. Video data is 3D since it spans height and width as well as time. Note that the channel axis is treated separately and does not count as a dimension.

import torch

from nd_vq_vae import NDimVQVAE

batch_size = 2
sequence_length = 3
channels = 3
res = (128, 256)
input_shape = (channels, sequence_length, res[0], res[1])

model = NDimVQVAE(
    embedding_dim=64,
    n_codes=64,
    n_dims=3,
    downsample=(4, 4, 4),  # example per-dimension factors (time, height, width)
    n_hiddens=64,
    n_res_layers=2,
    codebook_beta=0.10,
    input_shape=input_shape,
)

x = torch.randn(batch_size, *input_shape)
recon_loss, x_recon, vq_output = model(x)
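
The reconstruction keeps the input's shape, so a quick sanity check after the forward pass is:

assert x_recon.shape == x.shape   # (batch, channels, time, height, width)
print(recon_loss.item())          # scalar reconstruction loss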

3D: Train on Videos

Videos are 3-dimensional data with axes (Time, Height, Width). You can construct a video dataset at data/video_dataset/ as follows:

$ tree data/video_dataset/
data/video_dataset/
├── test
│   └── Gu1D3BnIYZg.mkv  # you can add more videos to both folders
└── train
    └── ceEE_oYuzS4.mp4
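
To turn one of these files into the (channels, time, height, width) tensor the model expects, here is a minimal preprocessing sketch using torchvision (the frame count, resolution, and normalization are assumptions about your pipeline, not NDimVQVAE requirements):

import torch.nn.functional as F
from torchvision.io import read_video

# read_video returns frames as (T, H, W, C) uint8
frames, _, _ = read_video("data/video_dataset/train/ceEE_oYuzS4.mp4", pts_unit="sec")
frames = frames[:3].permute(0, 3, 1, 2).float() / 255.0  # (T, C, H, W) in [0, 1]
frames = F.interpolate(frames, size=(128, 256))          # resize every frame
x = frames.permute(1, 0, 2, 3).unsqueeze(0)              # (1, C, T, H, W)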

2D: Train on Images

Images are 2-dimensional data with axes (Height, Width). You can construct an image dataset at data/image_dataset/ as follows:

$ tree data/image_dataset/
data/image_dataset/
├── test
│   └── 0000001.png  # you can add more images to both folders
└── train
    └── 0000001.png

Then you can use the image training script:

python scripts/train_image.py --data_path data/image_dataset/
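
To feed an image through the model directly rather than via the script, the 2D case follows the same pattern as the video example (a sketch; the normalization is your preprocessing choice):

from torchvision.io import read_image

img = read_image("data/image_dataset/train/0000001.png").float() / 255.0  # (C, H, W)
x = img.unsqueeze(0)   # (1, C, H, W) batch for an n_dims=2 model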

1D

Coming soon!

Hyperparameters: How to tune my VQ-VAE?

The VQ-VAE has the following hyperparameters (an illustrative configuration follows the list):

  1. Codebook size and embedding dimension (n_codes and embedding_dim)
  2. Model capacity (n_hiddens and n_res_layers)
  3. Downsampling strategy (downsample)
  4. Loss balancing (codebook_beta and recon_loss_factor)
  5. Optimization parameters (learning_rate, beta1, beta2)
  6. Training parameters (batch_size, num_epochs)
  7. Attention mechanism (n_head, attn_dropout)
  8. Codebook update strategy (ema_decay)
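
For orientation, here is where the first few knobs live, assuming the constructor from the Usage section and a standard Adam optimizer (the optimizer setup and all values are illustrative, not library defaults; the remaining knobs belong to the training scripts):

import torch
from nd_vq_vae import NDimVQVAE

model = NDimVQVAE(
    embedding_dim=64, n_codes=64,  # 1. codebook size and embedding dimension
    n_hiddens=64, n_res_layers=2,  # 2. model capacity
    downsample=(4, 4, 4),          # 3. downsampling strategy
    codebook_beta=0.10,            # 4. loss balancing
    n_dims=3, input_shape=(3, 3, 128, 256),
)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.999))  # 5.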

Below is how you would tune these parameters based on the recon_loss and commitment_loss curves.

  1. Monitor Key Metrics
  • Track these metrics during training and validation (a monitoring sketch follows this list):
    • Reconstruction Loss
    • Commitment Loss
    • Perplexity
    • Note that the reconstruction loss is in the same units as the input data: if the input data is in meters, the reconstruction loss will also be in meters!
  2. Analyze Loss Curves
  • Reconstruction Loss
    • High and not decreasing: Increase model capacity (n_hiddens, n_res_layers) or adjust learning rate.
    • Train decreasing, val stable: Potential overfitting. Reduce capacity or add regularization.
    • Both decreasing, val much higher: Increase batch_size or use data augmentation.
  • Commitment Loss
    • Too high: Decrease codebook_beta.
    • Too low or unstable: Increase codebook_beta.
  3. Balance Losses
  • Adjust codebook_beta and recon_loss_factor to achieve a good balance between reconstruction and commitment losses.
  4. Optimize Codebook Usage
  • Monitor perplexity:
    • Low perplexity: Increase n_codes or decrease embedding_dim.
    • High perplexity: Decrease n_codes or increase embedding_dim.
  5. Fine-tune Learning Dynamics
  • Slow convergence: Increase learning_rate or adjust optimizer parameters.
  • Unstable training: Decrease learning_rate or increase batch_size.
  6. Address Overfitting
  • If validation loss plateaus while training loss decreases:
    • Introduce dropout in encoder/decoder
    • Reduce model capacity
    • Increase batch_size or use data augmentation
  7. Attention Mechanism
  • Adjust n_head and attn_dropout in attention blocks for better long-range dependencies.
  8. Codebook Update Strategy
  • Fine-tune ema_decay for codebook stability and adaptation speed.
  9. Downsampling Strategy
  • Adjust downsample factors based on computational resources and required detail level.
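
As referenced in step 1, a minimal monitoring loop might look like the sketch below. The vq_output keys are assumptions inferred from the metrics above (check the library for the actual structure), and perplexity is computed as the exponential of the entropy of the empirical code distribution:

import torch

def perplexity(codes, n_codes):
    # exp(entropy) of the code-usage distribution; its maximum is n_codes,
    # so low values mean most of the codebook is going unused
    probs = torch.bincount(codes.flatten(), minlength=n_codes).float()
    probs = probs / probs.sum()
    return torch.exp(-(probs * torch.log(probs + 1e-10)).sum())

for step, batch in enumerate(loader):  # loader yields batches shaped like x above
    recon_loss, x_recon, vq_output = model(batch)
    # Assumed: total loss = reconstruction term + commitment term
    loss = recon_loss + vq_output["commitment_loss"]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(step, recon_loss.item(), vq_output["commitment_loss"].item())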

Best Practices

  • Make incremental changes to hyperparameters.
  • Perform ablation studies, changing one parameter at a time.
  • Consider using learning rate scheduling or cyclical learning rates.
  • Regularly save checkpoints and log experiments for comparison.


Development

Read the CONTRIBUTING.md file.
