
RAFT (Recurrent All Pairs Field Transforms for Optical Flow) implementation via tf.keras


tf-raft

RAFT (Recurrent All Pairs Field Transforms for Optical Flow, Teed et al., ECCV 2020) implementation via tf.keras

Original resources

Installation

$ pip install tf-raft

Alternatively, you can simply clone this repository.

Dependencies

  • TensorFlow
  • TensorFlow-addons
  • albumentations

See pyproject.toml for details.

Optical flow datasets

The MPI-Sintel and FlyingChairs datasets are relatively small. See the original repository for more datasets.
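Both MPI-Sintel and FlyingChairs ship their ground-truth flow as Middlebury `.flo` files. The following is a minimal sketch of a reader for that format, using only the standard library. The name `read_flo` is hypothetical and is not part of tf-raft; the loaders in this repository may work differently.

```python
import struct

# Every .flo file starts with this float32 value as a sanity check.
FLO_MAGIC = 202021.25


def read_flo(path):
    """Read a Middlebury .flo file (hypothetical helper, not tf-raft API).

    Returns (width, height, flow), where flow[y][x] is the (u, v)
    displacement stored for pixel (x, y).
    """
    with open(path, "rb") as f:
        (magic,) = struct.unpack("<f", f.read(4))
        if magic != FLO_MAGIC:
            raise ValueError(f"{path}: not a .flo file (bad magic number)")
        width, height = struct.unpack("<ii", f.read(8))
        count = 2 * width * height  # u and v are interleaved per pixel
        values = struct.unpack(f"<{count}f", f.read(4 * count))

    flow = [
        [
            (values[2 * (y * width + x)], values[2 * (y * width + x) + 1])
            for x in range(width)
        ]
        for y in range(height)
    ]
    return width, height, flow
```

For full-resolution frames, an array library would be a more practical choice than nested lists; the pure-Python version is only meant to show the layout.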

Usage

from tf_raft.model import RAFT, SmallRAFT
from tf_raft.losses import sequence_loss, end_point_error

# iters is the number of recurrent flow updates
raft = RAFT(iters=iters)
raft.compile(
    optimizer=optimizer,
    loss=sequence_loss,
    epe=end_point_error
)

raft.fit(
    dataset,
    epochs=epochs,
    callbacks=callbacks,
)

In practice, you need to prepare the dataset, optimizer, callbacks, etc. yourself; see train.py for details.
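To illustrate what the loss and the metric measure, here is a rough sketch in pure Python, based on the definitions in the RAFT paper rather than on this package's code. The names `epe_sketch` and `sequence_loss_sketch` are hypothetical, flows are plain lists of (u, v) tuples, and the per-pixel normalization is an assumption; the actual `sequence_loss` and `end_point_error` in tf_raft.losses may differ in detail.

```python
import math


def epe_sketch(pred, target):
    """End-point error: mean Euclidean distance between flow vectors.

    pred, target: lists of (u, v) tuples, one per pixel.
    """
    dists = [
        math.hypot(pu - tu, pv - tv)
        for (pu, pv), (tu, tv) in zip(pred, target)
    ]
    return sum(dists) / len(dists)


def sequence_loss_sketch(predictions, target, gamma=0.8):
    """Exponentially weighted L1 loss over all recurrent predictions.

    predictions: one flow estimate per recurrent update, first to last.
    gamma: decay factor (0.8 in the RAFT paper); the last prediction
    gets weight 1 and earlier ones are down-weighted.
    """
    n = len(predictions)
    total = 0.0
    for i, pred in enumerate(predictions):
        weight = gamma ** (n - i - 1)
        l1 = sum(
            abs(pu - tu) + abs(pv - tv)
            for (pu, pv), (tu, tv) in zip(pred, target)
        )
        # Assumption: average the L1 distance over pixels.
        total += weight * l1 / len(target)
    return total
```

Because every intermediate estimate contributes to the loss, increasing `iters` gives the model more supervised refinement steps, at the cost of more computation per batch.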

Note

Though I have tried to reproduce the original implementation faithfully, there are some differences between it and my implementation (mainly because of the framework used: PyTorch vs. TensorFlow):

  • The original implements a CUDA-based correlation function, but I don't. My TF-based implementation works well, but the CUDA-based one may run faster.
  • I have trained my model only on the MPI-Sintel dataset, in my private environment (GCP with a P100 accelerator). Training went well, but the model has not reached the best score reported in the paper, which was obtained by training on multiple datasets.
  • The original uses mixed precision, which may make training much faster, but I don't. TensorFlow also supports mixed precision with a few additional lines; see https://www.tensorflow.org/guide/mixed_precision if interested.
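For readers unfamiliar with the correlation function mentioned in the first point, the sketch below shows what an all-pairs correlation computes according to the RAFT paper: the dot product between every feature vector of the first image and every feature vector of the second, scaled by the square root of the feature dimension. It is a pure-Python illustration under those assumptions; `all_pairs_correlation` is a hypothetical name and is not the code of either implementation.

```python
import math


def all_pairs_correlation(fmap1, fmap2):
    """corr[i][j][k][l] = dot(fmap1[i][j], fmap2[k][l]) / sqrt(D).

    fmap1, fmap2: nested lists of shape (H, W, D) holding per-pixel
    feature vectors. The result has shape (H, W, H, W).
    """
    h, w = len(fmap1), len(fmap1[0])
    d = len(fmap1[0][0])
    scale = 1.0 / math.sqrt(d)
    return [
        [
            [
                [
                    scale * sum(a * b for a, b in zip(fmap1[i][j], fmap2[k][l]))
                    for l in range(w)
                ]
                for k in range(h)
            ]
            for j in range(w)
        ]
        for i in range(h)
    ]
```

The output grows with the square of the number of pixels, which is why the speed and memory use of this step matter in practice.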

References

