A Framework for Training and Evaluating Video Prediction Models
Introduction
Video prediction (VP) is the task of predicting future video frames from a number of given context frames.
As in most computer vision sub-domains, scientific contributions in this field vary widely in the following aspects:
- Training protocol (dataset usage, when to backprop, value ranges etc.)
- Technical details of model implementation (deep learning framework, package dependencies etc.)
- Benchmark selection and execution (this includes the choice of dataset, number of context/predicted frames, skipping frames in the observed sequences etc.)
- Evaluation protocol (metrics chosen, variations in implementation/reduction modes, different ways of creating visualizations etc.)
Furthermore, while many authors nowadays share their code, seemingly minor missing pieces such as dataloaders make it much harder to assess, compare and improve existing models.
This repo provides a suite that facilitates scientific work in the subfield, offering standardized yet customizable solutions for the aspects mentioned above. This way, validating existing VP models and creating new ones becomes much less tedious.
Installation
Requires `pip` and Python >= 3.6 (the code is tested with Python 3.8).
From PyPI:
```
pip install vp-suite
```
From source:
```
pip install git+https://github.com/Flunzmas/vp-suite.git
```
If you want to contribute:
```
git clone https://github.com/Flunzmas/vp-suite.git
cd vp-suite
pip install -e .[dev]
```
If you want to build docs:
```
git clone https://github.com/Flunzmas/vp-suite.git
cd vp-suite
pip install -e .[doc]
```
Usage
Changing save location
When using this package for the first time, the save location for datasets, models and logs is set to `<installation_dir>/vp-suite-data`. If you'd like to change that, simply run:
```
python vp_suite/resource/set_run_path.py
```
This script changes your save location and migrates any existing data.
Training models
```python
from vp_suite import VPSuite

# 1. Set up the VP Suite.
suite = VPSuite()

# 2. Load one of the provided datasets.
#    They will be downloaded automatically if no downloaded data is found.
suite.load_dataset("MM")  # load the Moving MNIST dataset from the default location

# 3. Create a video prediction model.
suite.create_model('convlstm-shi')  # create a ConvLSTM-based prediction model

# 4. Run the training loop, optionally providing custom configuration.
suite.train(lr=2e-4, epochs=100)
```
This code snippet will train the model, log training progress to your Weights & Biases account, save model checkpoints on improvement and generate and save prediction visualizations.
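The "save model checkpoints on improvement" behavior can be illustrated with a minimal, self-contained sketch (illustrative only; `run_epochs` and `save_fn` are hypothetical names, not vp_suite's actual training loop):

```python
# Minimal sketch of the "save checkpoints on improvement" idea
# (illustrative only; not vp_suite's actual training loop).
def run_epochs(val_losses, save_fn):
    """Call save_fn(epoch) whenever the validation loss improves."""
    best = float("inf")
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            save_fn(epoch)  # e.g. torch.save(model.state_dict(), ...)
    return best

saved = []
best = run_epochs([0.9, 0.7, 0.8, 0.5], saved.append)
print(saved, best)  # → [0, 1, 3] 0.5
```

Only improving epochs trigger a save, which keeps the checkpoint on disk in sync with the best validation performance seen so far.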
Evaluating models
```python
from vp_suite import VPSuite

# 1. Set up the VP Suite.
suite = VPSuite()

# 2. Load one of the provided datasets in test mode.
#    They will be downloaded automatically if no downloaded data is found.
suite.load_dataset("MM", split="test")  # load the Moving MNIST dataset from the default location

# 3. Get the filepaths to the models you'd like to test and load the models.
model_dirs = ["out/model_foo/", "out/model_bar/"]
for model_dir in model_dirs:
    suite.load_model(model_dir, ckpt_name="best_model.pth")

# 4. Test the loaded models on the loaded test sets.
suite.test(context_frames=5, pred_frames=10)
```
This code will evaluate the loaded models on the loaded dataset (its test portion, if available), creating detailed summaries of prediction performance across a customizable set of metrics. The results as well as prediction visualizations are saved and logged to Weights & Biases.
Note 1: If the specified evaluation protocol or the loaded dataset is incompatible with one of the models, this will raise an error with an explanation.
Note 2: By default, a CopyLastFrame baseline is also loaded and tested with the other models.
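The idea behind a CopyLastFrame baseline can be sketched in a few lines (illustrative; this is not the package's actual baseline class): every predicted frame is simply a copy of the last context frame.

```python
# Sketch of the CopyLastFrame idea (illustrative; not the package's
# actual implementation): repeat the last context frame as the prediction.
def copy_last_frame(context_frames, pred_frames):
    """context_frames: list of frames; returns pred_frames copies of the last one."""
    return [context_frames[-1]] * pred_frames

context = ["frame0", "frame1", "frame2"]
print(copy_last_frame(context, 2))  # → ['frame2', 'frame2']
```

Such a trivial baseline is useful as a sanity check: a trained model should at least outperform simply repeating the last observed frame.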
Hyperparameter Optimization
This package uses optuna to provide hyperparameter optimization functionality. The following snippet provides a full example:
```python
import json

from vp_suite import VPSuite
from vp_suite.defaults import SETTINGS

suite = VPSuite()
suite.load_dataset(dataset="KTH")  # select dataset of choice
suite.create_model(model_id="lstm")  # select model of choice

with open(str((SETTINGS.PKG_RESOURCES / "optuna_example_config.json").resolve()), 'r') as cfg_file:
    optuna_cfg = json.load(cfg_file)
# optuna_cfg specifies the parameters' search intervals and scales; modify as you wish.

suite.hyperopt(optuna_cfg, n_trials=30, epochs=10)
```
This code will run 30 training loops (called trials by optuna), producing a trained model for each hyperparameter configuration and writing the hyperparameter configuration of the best-performing run to the console.
Note 1: For hyperopt, visualization, logging and model checkpointing are minimized to reduce I/O strain.
Note 2: Despite optuna's trial pruning capabilities, running a high number of trials might still take a lot of time. In that case, consider e.g. reducing the number of training epochs.
Use `no_wandb=True` if you want to log outputs to the console instead of Weights & Biases, and `no_vis=True` if you don't want to generate and save visualizations.
Notes:
- Use `VPSuite.list_available_models()` and `VPSuite.list_available_datasets()` to get an overview of which models and datasets are currently covered by the framework.
- All training, testing and hyperparameter optimization calls can be heavily configured (adjusting training hyperparameters, logging behavior etc.). For a comprehensive list of all adjustable run configuration parameters, see the documentation of the `vp_suite.defaults` package.
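The way such run configuration parameters are typically handled can be sketched as follows (a minimal, hypothetical example; `merge_config` and `DEFAULT_TRAIN_CONFIG` are illustrative names, and vp_suite's actual defaults live in the `vp_suite.defaults` package):

```python
# Hypothetical sketch of validating and merging run-config overrides with
# defaults (names illustrative; not vp_suite's internals).
DEFAULT_TRAIN_CONFIG = {"lr": 1e-4, "epochs": 50, "no_wandb": False, "no_vis": False}

def merge_config(defaults, **overrides):
    """Reject unknown keys, then overlay the overrides onto the defaults."""
    unknown = set(overrides) - set(defaults)
    if unknown:
        raise ValueError(f"unknown config parameters: {sorted(unknown)}")
    return {**defaults, **overrides}

cfg = merge_config(DEFAULT_TRAIN_CONFIG, lr=2e-4, epochs=100)
print(cfg["lr"], cfg["epochs"])  # → 0.0002 100
```

Rejecting unknown keys early surfaces typos in configuration parameters instead of silently ignoring them.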
Customization
This package is designed with quick extensibility in mind. See the sections below for how to add new components (models, model blocks, datasets or measures).
New Models
- Create a file `<your name>.py` in the folder `vp_suite/models`.
- Create a class that derives from `vp_suite.base.base_model.VideoPredictionModel` and override/specify the new constants you need.
- Write your model code or import existing code so that the superclass interface is still served. If desired, you can implement a custom training/evaluation loop iteration `train_iter()`/`eval_iter()` that gets called instead of the default training/evaluation loop iteration.
- Register your model in the `MODEL_CLASSES` dictionary of `vp_suite/models/__init__.py`, giving it a key that can be used by the suite.

By now, you should be able to create an instance of your model with `VPSuite.create_model()` and train it on a dataset with `VPSuite.train()`.
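The registration pattern described above can be illustrated with a self-contained sketch (stand-in classes only; the actual `VideoPredictionModel` interface and `MODEL_CLASSES` contents are defined by the package, not here):

```python
# Self-contained sketch of the model-registration pattern (stand-in
# classes; not the actual vp_suite API).
class VideoPredictionModel:           # stand-in for vp_suite's base class
    NAME = "base"
    def train_iter(self, batch):
        raise NotImplementedError

class MyModel(VideoPredictionModel):  # your new model
    NAME = "my-model"
    def train_iter(self, batch):      # custom training-loop iteration
        return sum(batch)             # placeholder "loss"

# analogous to MODEL_CLASSES in vp_suite/models/__init__.py
MODEL_CLASSES = {"my-model": MyModel}

model = MODEL_CLASSES["my-model"]()   # roughly what create_model("my-model") does
print(model.train_iter([1, 2, 3]))    # → 6
```

Once the class is registered under a key, the suite can instantiate it by name without knowing anything about the concrete implementation.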
New Model Blocks
- Create a file `<your name>.py` in the folder `vp_suite/model_blocks`.
- Create a class that derives from `vp_suite.base.base_model_block.ModelBlock` and override/specify the new constants you need.
- Write your model block code or import existing code so that the superclass interface is still served.
- If desired, add a local import of your model block to `vp_suite/model_blocks/__init__.py` (this registers the model block package-wide).
New Datasets
- Create a file `<your name>.py` in the folder `vp_suite/datasets`.
- Create a class that derives from `vp_suite.base.base_dataset.BaseVPDataset` and override/specify the new constants you need.
- Write your dataset code or import existing code so that the superclass interface is served. If it's a public dataset, consider implementing methods to automatically download it.
- Register your dataset in the `DATASET_CLASSES` dict of `vp_suite/dataset/__init__.py`, giving it a key that can be used by the suite.

By now, you should be able to load your dataset with `VPSuite.load_dataset()` and train models on it with `VPSuite.train()`.
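The "download automatically if no data is found" behavior mentioned above can be sketched as follows (illustrative only; `ensure_dataset` is a hypothetical helper, not the actual `BaseVPDataset` logic):

```python
import pathlib
import tempfile

# Sketch of "download the dataset only if it isn't there yet"
# (illustrative; not the actual BaseVPDataset logic).
def ensure_dataset(data_dir, download_fn):
    """Invoke download_fn(path) only if the dataset folder is missing or empty."""
    path = pathlib.Path(data_dir)
    if not path.exists() or not any(path.iterdir()):
        path.mkdir(parents=True, exist_ok=True)
        download_fn(path)
    return path

calls = []
def fake_download(path):
    calls.append(path)
    (path / "data.bin").write_bytes(b"\x00")  # pretend we downloaded something

with tempfile.TemporaryDirectory() as tmp:
    target = pathlib.Path(tmp) / "my_dataset"
    ensure_dataset(target, fake_download)  # folder missing -> downloads
    ensure_dataset(target, fake_download)  # data present -> skipped
print(len(calls))  # → 1
```

Guarding the download this way makes repeated `load_dataset`-style calls idempotent: the expensive fetch happens at most once.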
New measures (losses and/or metrics)
- Create a new file `<your name>.py` in the folder `vp_suite/measure`, containing your loss or metric.
- Make `vp_suite.base.base_measure.BaseMeasure` its superclass and override/implement all needed methods and constants.
- Register the measure in the `METRIC_CLASSES` dict of `vp_suite/measure/__init__.py` and, if it can also be used as a loss, in the `LOSS_CLASSES` dict.
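A measure that serves both as metric and loss can be sketched with stand-in classes (illustrative only; the actual `BaseMeasure` interface and registry contents are defined by the package):

```python
# Sketch of the measure-registration pattern (stand-in base class; not the
# actual vp_suite.base.base_measure.BaseMeasure interface).
class BaseMeasure:
    def __call__(self, pred, target):
        raise NotImplementedError

class MAE(BaseMeasure):
    """Mean absolute error - usable both as a metric and as a loss."""
    def __call__(self, pred, target):
        return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)

# analogous to the dicts in vp_suite/measure/__init__.py
METRIC_CLASSES = {"mae": MAE}
LOSS_CLASSES = {"mae": MAE}  # registered twice since it can also serve as a loss

print(METRIC_CLASSES["mae"]()([1.0, 2.0], [1.5, 1.0]))  # → 0.75
```

Registering the same class in both dicts is what lets the suite offer it for evaluation and for training alike.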
Notes:
- If you omit the docstring for a particular attribute/method/field, the docstring of the base class is used for documentation.
- If you implement components that originate from publications/public repositories, please override the corresponding constants to specify the source! Additionally, if you want to write automated tests checking implementation equality, have a look at how `tests/test_impl_match.py` fetches the tests of `tests/test_impl_match/` and executes them.
- Basic unit tests for models, datasets and measures are executed on all registered components - you don't need to write such basic tests for your custom components! The same applies to documentation: the tables that list available components are filled automatically.
Contributing
This project is always open to extension! It grows especially powerful with more models and datasets, so if you've made your code work on custom models/datasets/metrics/etc., feel free to submit a merge request!
Other kinds of contributions are also very welcome - just check the open issues on the tracker or open up a new issue there.
Unit Testing
When submitting a merge request, please make sure all tests pass (execute from the root folder):
```
python -m pytest --runslow --cov=vp_suite
```
Note: this is the easiest way to run all tests without import hassles. You will need to have `vp-suite` installed in development mode, though (see the installation section above).
API Documentation
The official API documentation is updated automatically upon push to the main branch. If you want to build the documentation locally, make sure you've installed the package accordingly and execute the following:
```
cd docs/
bash assemble_docs.sh
```
Acknowledgements
- Project structure is inspired by segmentation_models.pytorch.
- Sphinx-autodoc templates are inspired by the QNET repository.
All other sources are acknowledged in the documentation of the respective point of usage (to the best of our knowledge).
License
This project comes with an MIT License, except for the following components:
- Module `vp_suite.measure.fvd.pytorch_i3d` (Apache 2.0 License, taken and modified from here)
Disclaimer
I do not host or distribute any dataset. For all provided dataset functionality, I trust you have the permission to download and use the respective data.