Code and models accompanying Deep Predictive Coding Networks for Video Prediction and Unsupervised Learning by Bill Lotter, Gabriel Kreiman, and David Cox. The PredNet is a deep recurrent convolutional neural network that is inspired by the neuroscience concept of predictive coding (Rao and Ballard, 1999; Friston, 2005).
If prednet has been uploaded to a devpi instance your pip is connected to, then you can install with:
pip install prednet
You can always install the bleeding-edge updates with:
pip install git+ssh://git@github.com/coxlab/prednet.git@master
Check out example prediction videos here.
The architecture is implemented as a custom layer [1] in Keras. Code and model data are compatible with Keras 2.0 and Python 2.7 and 3.6. The latest version has been tested on Keras 2.2.4 with TensorFlow 1.6. For previous versions of the code compatible with Keras 1.2.1, use fbcdc18. To convert old PredNet model files and weights for Keras 2.0 compatibility, see convert_model_to_keras2 in keras_utils.py.
Code is included for training the PredNet on the raw KITTI dataset. We include code for downloading and processing the data, as well as training and evaluating the model. The preprocessed data can also be downloaded directly using download_data.sh, and the trained weights by running download_models.sh. The model download will include the original weights trained for t+1 prediction, the fine-tuned weights trained to extrapolate predictions for multiple timesteps, and the “L_all” weights trained with a 0.1 loss weight on upper layers (see paper for details).
The data-processing step scrapes the KITTI website to download the raw data from the city, residential, and road categories (~165 GB) and then processes the images (cropping, downsampling). Alternatively, the processed data (~3 GB) can be downloaded directly by executing download_data.sh.
It is currently unclear how to fix the “RuntimeError: Cannot open file.” error (see https://github.com/coxlab/prednet/issues/53), so for now use LanaSina’s kitti_hkl files at https://figshare.com/articles/KITTI_hkl_files/7985684
$ python kitti_train.py
Using TensorFlow backend.
Epoch 1/150
56/125 [============>.................] - ETA: 8:29 - loss: 0.0649
This will train a PredNet model for t+1 prediction. See the Keras FAQ for how to run on a GPU. To download pre-trained weights, run download_models.sh.
Currently there are the following warnings on newer versions of TensorFlow:
site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.
site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.
site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.
site-packages/keras/backend/tensorflow_backend.py:2018: The name tf.image.resize_nearest_neighbor is deprecated. Please use tf.compat.v1.image.resize_nearest_neighbor instead.
site-packages/keras/backend/tensorflow_backend.py:2018: The name tf.image.resize_nearest_neighbor is deprecated. Please use tf.compat.v1.image.resize_nearest_neighbor instead.
site-packages/keras/backend/tensorflow_backend.py:3976: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.
site-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.
site-packages/keras/backend/tensorflow_backend.py:181: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.
site-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.
site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Evaluating the model outputs the mean-squared error for predictions and produces plots comparing predictions to ground truth.
Extracting the intermediate features for a given layer in the PredNet can be done using the appropriate output_mode argument. For example, to extract the hidden state of the LSTM (the “Representation” units) in the lowest layer, use output_mode = 'R0'. More details can be found in the PredNet docstring.
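As a small illustration of the naming convention, the sketch below splits an output_mode string such as 'R0' into a unit-type prefix and a layer index. The helper split_output_mode is hypothetical and is not part of the PredNet code; which prefixes and layer numbers are actually valid is defined by the PredNet docstring, not by this example.

```python
import re


def split_output_mode(output_mode):
    """Split a string like 'R0' into (unit_type, layer_index).

    Hypothetical helper for illustration only. Strings without a
    trailing layer number (e.g. 'prediction') are returned with a
    layer index of None.
    """
    match = re.fullmatch(r"([A-Za-z]+)(\d+)", output_mode)
    if match is None:
        return output_mode, None
    return match.group(1), int(match.group(2))


print(split_output_mode("R0"))          # ('R', 0)
print(split_output_mode("prediction"))  # ('prediction', None)
```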
The PredNet argument extrap_start_time can be used to force multi-step prediction. Starting at this time step, the prediction from the previous time step will be treated as the actual input. For example, if the model is run on a sequence of 15 timesteps with extrap_start_time = 10, the last output will correspond to a t+5 prediction. In the paper, we train in this setting starting from the original t+1 trained weights (see kitti_extrap_finetune.py), and the resulting fine-tuned weights are included in download_models.sh. Note that when training with extrapolation, the “errors” are no longer tied to ground truth, so the loss should be calculated on the pixel predictions themselves. This can be done by using output_mode = 'prediction', as illustrated in kitti_extrap_finetune.py.
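The relationship between extrap_start_time and how far ahead each output is predicting can be sketched as follows. This is a minimal illustration, not taken from the PredNet code: the function prediction_horizon is hypothetical, and it assumes time steps are 0-indexed and that every output before extrap_start_time is a one-step-ahead (t+1) prediction. Under those assumptions it reproduces the 15-timestep example above.

```python
def prediction_horizon(t, extrap_start_time=None):
    """Number of steps ahead that the output at time step t predicts.

    Hypothetical helper; assumes 0-indexed time steps. Before
    extrap_start_time the model sees real frames, so every output is a
    one-step prediction. From extrap_start_time onward the model is fed
    its own previous prediction, so the horizon grows by one per step.
    """
    if extrap_start_time is None or t < extrap_start_time:
        return 1
    return t - extrap_start_time + 1


nt = 15
horizons = [prediction_horizon(t, extrap_start_time=10) for t in range(nt)]
print(horizons)      # ten 1s for steps 0-9, then 1, 2, 3, 4, 5
print(horizons[-1])  # 5, i.e. the last output is a t+5 prediction
```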
When training on a new dataset, the image height and width each have to be divisible by 2^(number of layers - 1) because of the cyclical 2x2 max-pooling and upsampling operations.
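A quick way to check a candidate image size before training is sketched below. The function names and the example sizes are illustrative assumptions, not part of the PredNet code; the only rule taken from the text is that each spatial dimension must be a multiple of 2^(number of layers - 1).

```python
def required_divisor(nb_layers):
    """Factor that both image height and width must be divisible by."""
    # Each of the (nb_layers - 1) transitions between layers applies
    # 2x2 pooling, halving the spatial resolution.
    return 2 ** (nb_layers - 1)


def is_valid_image_size(height, width, nb_layers):
    """True if both dimensions survive repeated 2x2 pooling/upsampling."""
    d = required_divisor(nb_layers)
    return height % d == 0 and width % d == 0


def round_down_to_valid(size, nb_layers):
    """Largest valid dimension not exceeding `size` (e.g. as a crop target)."""
    d = required_divisor(nb_layers)
    return (size // d) * d


print(required_divisor(4))               # 8
print(is_valid_image_size(128, 160, 4))  # True
print(is_valid_image_size(100, 160, 4))  # False: 100 is not a multiple of 8
print(round_down_to_valid(100, 4))       # 96
```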
[1] Note on implementation: PredNet inherits from the Recurrent layer class, i.e. it has an internal state and a step function. Given the top-down then bottom-up update sequence, it must currently be implemented in Keras as essentially a ‘super’ layer where all layers in the PredNet are in one PredNet ‘layer’. This is less than ideal, but it seems like the most efficient way as of now. We welcome suggestions if anyone thinks of a better implementation.
To run all the tests, run:
tox
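Note: to combine the coverage data from all the tox environments, run the following (Windows cmd syntax; in a POSIX shell, use PYTEST_ADDOPTS=--cov-append tox on a single line):
set PYTEST_ADDOPTS=--cov-append
tox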
0.0.0 (8 July 2016)
- First release on PyPI.