
RAFT (Recurrent All-Pairs Field Transforms for Optical Flow) implementation via tf.keras

Project description

tf-raft

RAFT (Recurrent All-Pairs Field Transforms for Optical Flow, Teed et al., ECCV 2020) implementation via tf.keras

Original resources

Installation

$ pip install tf-raft

or you can simply clone this repository.

Dependencies

  • TensorFlow
  • TensorFlow-addons
  • albumentations

See details in pyproject.toml.

Optical flow datasets

The MPI-Sintel and FlyingChairs datasets are relatively lightweight. See the original repository for more datasets.

Usage

from tf_raft.model import RAFT, SmallRAFT
from tf_raft.losses import sequence_loss, end_point_error

# iters / iters_pred: number of recurrent flow updates during training / prediction
raft = RAFT(iters=iters, iters_pred=iters_pred)
raft.compile(
    optimizer=optimizer,
    clip_norm=clip_norm,
    loss=sequence_loss,
    epe=end_point_error
)

raft.fit(
    ds_train,
    epochs=epochs,
    callbacks=callbacks,
    steps_per_epoch=train_size//batch_size,
    validation_data=ds_val,
    validation_steps=val_size
)

In practice, you need to prepare the dataset, optimizer, callbacks, etc.; see train_sintel.py or train_chairs.py for details.
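For reference, sequence_loss and end_point_error follow the definitions in the RAFT paper: an exponentially weighted L1 loss over the sequence of iterative flow predictions (later iterations weighted more heavily), and the mean Euclidean distance to the ground-truth flow. The NumPy sketch below illustrates the math only; the actual implementations in tf_raft.losses operate on TensorFlow tensors and may differ in detail:

```python
import numpy as np

def sequence_loss(flow_gt, flow_predictions, gamma=0.8):
    """Exponentially weighted L1 loss over the iterative predictions:
    sum_i gamma**(N - i - 1) * mean(|flow_gt - flow_i|)."""
    n = len(flow_predictions)
    loss = 0.0
    for i, flow in enumerate(flow_predictions):
        weight = gamma ** (n - i - 1)  # last prediction gets weight 1
        loss += weight * np.mean(np.abs(flow_gt - flow))
    return loss

def end_point_error(flow_gt, flow_pred):
    """Mean Euclidean distance between predicted and ground-truth flow."""
    return np.mean(np.sqrt(np.sum((flow_gt - flow_pred) ** 2, axis=-1)))
```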

Train via YAML configuration

train_chairs.py and train_sintel.py train the RAFT model from a YAML configuration. Sample configs are in the configs directory. Run:

$ python train_chairs.py /path/to/config.yml
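The exact schema is defined by the training scripts and the sample configs; the fragment below is purely illustrative (all field names are hypothetical) of the kind of settings such a config typically carries:

```yaml
# Hypothetical config sketch -- field names are illustrative only;
# consult the configs/ directory in the repository for the real schema.
model:
  iters: 12        # recurrent flow updates during training
  iters_pred: 24   # recurrent flow updates at prediction time
train:
  epochs: 100
  batch_size: 6
  learning_rate: 0.0004
  clip_norm: 1.0   # global gradient clipping threshold
```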

Pre-trained models

I have made the pre-trained weights public (trained separately on FlyingChairs and on MPI-Sintel). You can download them via gsutil or curl.

Trained weights on FlyingChairs

$ gsutil cp -r gs://tf-raft-pretrained/2020-09-26T18-38/checkpoints .

or

$ mkdir checkpoints
$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T18-38/checkpoints/model.data-00000-of-00001
$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T18-38/checkpoints/model.index
$ mv model* checkpoints/

Trained weights on MPI-Sintel (Clean path)

$ gsutil cp -r gs://tf-raft-pretrained/2020-09-26T08-51/checkpoints .

or

$ mkdir checkpoints
$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T08-51/checkpoints/model.data-00000-of-00001
$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T08-51/checkpoints/model.index
$ mv model* checkpoints/

Load weights

import numpy as np

raft = RAFT(iters=iters, iters_pred=iters_pred)
raft.load_weights('checkpoints/model')

# forward pass (with dummy inputs)
x1 = np.random.uniform(0, 255, (1, 448, 512, 3)).astype(np.float32)
x2 = np.random.uniform(0, 255, (1, 448, 512, 3)).astype(np.float32)
flow_predictions = raft([x1, x2], training=False)

print(flow_predictions[-1].shape)  # >> (1, 448, 512, 2)

Note

Although I have tried to reproduce the original implementation faithfully, there are some differences between the original and mine (mainly due to the frameworks used, PyTorch vs. TensorFlow):

  • The original implementation provides a CUDA-based correlation function; I do not. My TF-based implementation works well, but the CUDA-based one may run faster.
  • I trained my models on FlyingChairs and MPI-Sintel separately in my private environment (GCP with a P100 accelerator). The models train well, but do not reach the best scores reported in the paper (which were obtained by training on multiple datasets).
  • The original uses mixed precision, which can speed up training considerably; I do not. TensorFlow also supports mixed precision with a few additional lines; see https://www.tensorflow.org/guide/mixed_precision if interested.

Additionally, global gradient clipping seems essential for stable training, though it is not emphasized in the original paper. This operation is done via torch.nn.utils.clip_grad_norm_(model.parameters(), clip) in PyTorch and tf.clip_by_global_norm(grads, clip_norm) in TensorFlow (implemented in self.train_step in tf_raft/model.py).
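For illustration, global-norm clipping computes the L2 norm over all gradients jointly and rescales them only when that norm exceeds the threshold. A NumPy sketch of the operation tf.clip_by_global_norm performs (the function here is an illustrative stand-in, not the TF implementation):

```python
import numpy as np

def clip_by_global_norm(grads, clip_norm):
    """Rescale all gradients jointly so their combined L2 norm is at
    most clip_norm (mirrors tf.clip_by_global_norm's behaviour)."""
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    # scale < 1 only when the global norm exceeds the threshold
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads], global_norm
```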


Download files

Download the file for your platform.

Source Distribution

tf-raft-0.1.4.tar.gz (19.8 kB)

Built Distribution

tf_raft-0.1.4-py3-none-any.whl (21.5 kB)

File details

Details for the file tf-raft-0.1.4.tar.gz.

File metadata

  • Download URL: tf-raft-0.1.4.tar.gz
  • Upload date:
  • Size: 19.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.0 CPython/3.8.5 Linux/5.4.0-1025-azure

File hashes

Hashes for tf-raft-0.1.4.tar.gz

  • SHA256: dea7ff166438f29bde599106c007a549d36ce5b3faa8b32c6f4f5f697ae20af8
  • MD5: a97c72380b0777e7d225e9af397923a9
  • BLAKE2b-256: 02f1e031d0bc7b7dadf115c8e6721fe0468e1bddf8375f49a8993108abcdb955


File details

Details for the file tf_raft-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: tf_raft-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 21.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.0 CPython/3.8.5 Linux/5.4.0-1025-azure

File hashes

Hashes for tf_raft-0.1.4-py3-none-any.whl

  • SHA256: 5aa818ec8a8ffecd742a66fbc3f18eee3d619a1e24e3c534fa50b58184ddacda
  • MD5: 96c6535cf90a5f96849753fb3a0c20bb
  • BLAKE2b-256: 19933d137a37310693f56068d75e0fbdb30ea1ca058965a4f604d1fba9d19f07

