A Framework for Training and Evaluating Video Prediction Models
Introduction
Video prediction ('VP') is the task of predicting future frames given some context frames.
As in most computer vision sub-domains, scientific contributions in this field vary widely in the following aspects:
- Training protocol (dataset usage, when to backprop, value ranges etc.)
- Technical details of model implementation (deep learning framework, package dependencies etc.)
- Benchmark selection and execution (this includes the choice of dataset, number of context/predicted frames, skipping frames in the observed sequences etc.)
- Evaluation protocol (metrics chosen, variations in implementation/reduction modes, different ways of creating visualizations etc.)
Furthermore, while many contributors nowadays do share their code, seemingly minor missing pieces such as dataloaders make it much harder to assess, compare and improve existing models.
This repo aims to provide a suite that facilitates scientific work in this subfield, offering standardized yet customizable solutions for the aspects listed above. This way, validating existing VP models and creating new ones hopefully becomes much less tedious.
Installation
Requirements:
- python >= 3.6 (the code is tested with version 3.8)
- pip
From PyPI:
pip install vp-suite
From source:
pip install git+https://github.com/Flunzmas/vp-suite.git
If you want to contribute:
git clone https://github.com/Flunzmas/vp-suite.git
cd vp-suite
pip install -e .[dev,doc]
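To quickly check that the installation worked, you can try importing the package (a simple sanity check, not an official verification step):
python -c "import vp_suite"
If this command exits without an error, the package is importable from your environment.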
Usage
When using this package, a folder vp-suite is created in your current working directory; it will contain all downloaded data as well as run logs, outputs and trained models.
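For example, after running any of the snippets below you can locate that folder like this (plain Python, not part of the vp-suite API):
from pathlib import Path
# downloaded data, run logs, outputs and trained models end up in this folder
print(Path.cwd() / "vp-suite")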
Training models
- Set up the trainer:
- Load one of the provided datasets (it will be downloaded automatically) or create your own:
- Create video prediction model (either from scratch or from a pretrained checkpoint, can be one of the provided models or your own):
- Run the training loop, optionally providing a custom configuration:
from vp_suite import VPSuite
suite = VPSuite()
suite.load_dataset("MM") # load moving MNIST dataset from default location
model_checkpoint = "" # Set to valid model path to load a checkpoint
if model_checkpoint != "":
suite.load_model(model_checkpoint)
else:
suite.create_model('lstm') # create a ConvLSTM-Based Prediction Model
suite.train(lr=2e-4, epochs=100)
This will train the model, log training progress to the console (and optionally to Weights & Biases), save model checkpoints on improvement and, optionally, generate and save prediction visualizations.
Evaluating models
- Set up the tester
- Load one of the provided datasets (it will be downloaded automatically) or create your own
- Load the models you'd like to test (by default, a CopyLastFrame baseline is already loaded)
- Run the testing on all models, optionally providing custom configuration of the evaluation protocol:
from vp_suite import VPSuite
suite = VPSuite()
suite.load_dataset("MM") # load moving MNIST dataset from default location
# get the filepaths to the models you'd like to test
model_dirs = ["out/model_foo/", "out/model_bar/"]
for model_dir in model_dirs:
suite.load_model(model_dir, ckpt_name="best_model.pth")
suite.test(context_frames=5, pred_frames=10)
This code will evaluate the loaded models on the loaded dataset (its test portion, if available), creating detailed summaries of prediction performance across a customizable set of metrics. Optionally, the results as well as prediction visualizations can be saved and logged to Weights & Biases.
Note: if the specified evaluation protocol or the loaded dataset is incompatible with one of the models, this will raise an error with an explanation.
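If you evaluate many models or protocols programmatically, you may want to catch that error and continue; the exact exception class is not specified here, so the broad except below is only a hedged sketch:
try:
    suite.test(context_frames=5, pred_frames=10)
except Exception as e:  # the concrete exception type depends on which compatibility check fails
    print(f"Skipping incompatible setup: {e}")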
Hyperparameter Optimization
This package uses optuna to provide hyperparameter optimization functionality. The following snippet provides a full example:
import json
from vp_suite import VPSuite
from vp_suite.constants import PKG_RESOURCES
suite = VPSuite()
suite.load_dataset(dataset="KTH") # select dataset of choice
suite.create_model(model_type="lstm") # select model of choice
with open(str((PKG_RESOURCES / "optuna_example_config.json").resolve()), 'r') as cfg_file:
optuna_cfg = json.load(cfg_file)
# optuna_cfg specifies the parameters' search intervals and scales; modify as you wish.
suite.hyperopt(optuna_cfg, n_trials=30, epochs=10)
This code will, for example, run 30 training loops (called trials by optuna), producing a trained model for each hyperparameter configuration and printing the hyperparameter configuration of the best-performing run to the console.
Note 1: For hyperopt, visualization, logging and model checkpointing are minimized to reduce I/O strain.
Note 2: Despite optuna's trial pruning capabilities, running a high number of trials might still take a lot of time. In that case, consider e.g. reducing the number of training epochs.
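For readers unfamiliar with optuna: conceptually, a hyperparameter search like the one suite.hyperopt() runs boils down to a loop like the stand-alone sketch below, which uses optuna's public API with a toy objective in place of an actual training run (it is not vp-suite's implementation):
import optuna

def objective(trial):
    # sample a candidate learning rate from a log-scaled interval
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    # stand-in for a short training run that returns a validation loss
    val_loss = (lr - 1e-3) ** 2
    return val_loss

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=30)
print(study.best_params)  # best hyperparameter configuration found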
Customization
While this package comes with a few pre-defined models/datasets/metrics etc. for your convenience, it was designed with quick extensibility in mind. See the sections below for how to add new models, datasets or metrics.
Creating new VP models or integrating existing external models
- Create a file model_<your name>.py in the folder vp_suite/models.
- Create a class that derives from vp_suite.models.base_model.VideoPredictionModel and override the things you need (see the skeleton sketched below).
- Write your model code or import existing code so that the superclass interface is still served. If desired, you can implement a custom training loop iteration train_iter(self, config, loader, optimizer, loss_provider, epoch) that gets called instead of the default training loop iteration.
- Check training performance on different datasets, fix things and contribute to the project 😊
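As a loose orientation, a new model file might start out like the hypothetical skeleton below; the attributes and abstract methods actually required are defined in vp_suite.models.base_model.VideoPredictionModel, so check that class rather than relying on the guessed forward() signature here (only the train_iter() signature is taken from the step above):
from vp_suite.models.base_model import VideoPredictionModel

class MyModel(VideoPredictionModel):
    """Hypothetical skeleton of a new prediction model."""

    def forward(self, x, **kwargs):
        # placeholder: consume context frames, return predicted future frames
        raise NotImplementedError

    def train_iter(self, config, loader, optimizer, loss_provider, epoch):
        # optional: custom training loop iteration replacing the default one
        raise NotImplementedError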
Training on custom datasets
- Create a file dataset_<your name>.py in the folder vp_suite/dataset.
- Create a class that derives from vp_suite.dataset.base_dataset.BaseVPDataset and override the things you need (see the sketch below).
- Write your dataset code or import existing code so that the superclass interface is served and the dataset initialization with vp_suite/dataset/factory.py still works.
- Register it in the DATASET_CLASSES dict of vp_suite/dataset/__init__.py.
- Run pytest, check training performance with different models, fix things and contribute to the project 😊
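A minimal hypothetical skeleton for such a dataset class could look as follows; the exact attributes and methods BaseVPDataset expects should be taken from the base class itself, the dunder methods below are only the usual PyTorch-style guess:
from vp_suite.dataset.base_dataset import BaseVPDataset

class MyDataset(BaseVPDataset):
    """Hypothetical skeleton of a new video prediction dataset."""

    def __len__(self):
        # number of available video sequences
        raise NotImplementedError

    def __getitem__(self, idx):
        # return one video sequence in the format expected by the trainer
        raise NotImplementedError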
Custom losses, metrics and optimization
- Create a new file in vp_suite/measure, containing your loss or metric.
- Make vp_suite.measure.base_measure.BaseMeasure its superclass and provide all needed implementations and attributes (see the sketch below).
- Register the measure in the METRIC_CLASSES dict of vp_suite/measure/__init__.py and, if it can also be used as a loss, in the LOSS_CLASSES dict.
- Run pytest, check training/evaluation performance with different models and datasets, fix things and contribute to the project 😊
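As an illustration, a simple mean-absolute-error measure might be sketched as follows; the forward() signature is an assumption, so align it with what vp_suite.measure.base_measure.BaseMeasure actually prescribes:
import torch
from vp_suite.measure.base_measure import BaseMeasure

class MAE(BaseMeasure):
    """Hypothetical mean absolute error measure/loss."""

    def forward(self, pred, target):
        # average absolute difference between predicted and ground-truth frames
        return torch.mean(torch.abs(pred - target))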
Contributing
This project is always open to extension! It grows especially powerful with more models and datasets, so if you've made your code work on custom models/datasets/metrics/etc., feel free to submit a merge request!
Other kinds of contributions are also very welcome - just check the open issues on the tracker or open up a new issue there.
When submitting a merge request, please make sure all tests run through (execute from root folder):
python -m pytest --runslow --cov=vp_suite
Note: this is the easiest way to run all tests without import hassles.
Omit the --runslow argument to speed up testing by skipping the tests that run the complete training/testing procedures.
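If you additionally want a browsable coverage report, the pytest-cov plugin that provides the --cov flag above can also write HTML output (to htmlcov/ by default):
python -m pytest --runslow --cov=vp_suite --cov-report=html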
API Documentation
Updating the API documentation can be done by executing build_docs.sh from the docs/ folder.
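In other words (assuming a bash-compatible shell):
cd docs
bash build_docs.sh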
Acknowledgements
- Project structure is inspired by segmentation_models.pytorch.
- Sphinx-autodoc templates are inspired by the QNET repository.
All other sources are acknowledged in the documentation of the respective point of usage (to the best of our knowledge).
License
This project comes with an MIT License, except for the following components:
- Module vp_suite.measure.fvd.pytorch_i3d (Apache 2.0 License, taken and modified from here)
Disclaimer
I do not host or distribute any dataset. For all provided dataset functionality, I trust that you have permission to download and use the respective data.