RAFT (Recurrent All Pairs Field Transforms for Optical Flow) implementation via tf.keras
tf-raft
RAFT (Recurrent All Pairs Field Transforms for Optical Flow, Teed et al., ECCV2020) implementation via tf.keras
Original resources
Installation
$ pip install tf-raft
or you can simply clone this repository.
Dependencies
- TensorFlow
- TensorFlow-addons
- albumentations
See pyproject.toml for details.
Optical flow datasets
The MPI-Sintel and FlyingChairs datasets are relatively lightweight. See the original repository for more datasets.
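Both MPI-Sintel and FlyingChairs ship their ground-truth flow in the Middlebury `.flo` format: a float32 magic number 202021.25 (the bytes "PIEH"), then int32 width and height, then row-major (u, v) float32 pairs. The NumPy sketch below reads and writes that format; `read_flo` and `write_flo` are hypothetical helper names, not part of tf-raft, and little-endian byte order is assumed.

```python
import numpy as np

FLO_MAGIC = 202021.25  # sanity-check value at the start of every .flo file


def write_flo(path, flow):
    """Write an (H, W, 2) flow field to a Middlebury .flo file."""
    flow = np.asarray(flow, dtype="<f4")
    height, width, channels = flow.shape
    if channels != 2:
        raise ValueError("flow must have shape (H, W, 2)")
    with open(path, "wb") as f:
        np.array([FLO_MAGIC], dtype="<f4").tofile(f)
        np.array([width, height], dtype="<i4").tofile(f)
        flow.tofile(f)  # row-major, (u, v) interleaved per pixel


def read_flo(path):
    """Read a Middlebury .flo file into an (H, W, 2) float32 array."""
    with open(path, "rb") as f:
        magic = np.fromfile(f, dtype="<f4", count=1)
        if magic.size != 1 or magic[0] != FLO_MAGIC:
            raise ValueError(f"{path} is not a valid .flo file")
        width, height = np.fromfile(f, dtype="<i4", count=2)
        data = np.fromfile(f, dtype="<f4", count=2 * int(width) * int(height))
    return data.reshape(int(height), int(width), 2)
```

The resulting arrays have the same (H, W, 2) layout as the model's flow predictions, so they can be fed into a tf.data pipeline as ground truth.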
Usage
from tf_raft.model import RAFT, SmallRAFT
from tf_raft.losses import sequence_loss, end_point_error

# iters/iters_pred: number of recurrent flow updates during training/prediction
raft = RAFT(iters=iters, iters_pred=iters_pred)
raft.compile(
    optimizer=optimizer,
    clip_norm=clip_norm,
    loss=sequence_loss,
    epe=end_point_error
)

raft.fit(
    ds_train,
    epochs=epochs,
    callbacks=callbacks,
    steps_per_epoch=train_size//batch_size,
    validation_data=ds_val,
    validation_steps=val_size
)
In practice, you need to prepare the dataset, optimizer, callbacks, etc.; see train_sintel.py or train_chairs.py for details.
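For orientation, the RAFT paper supervises every intermediate prediction with an exponentially weighted L1 loss, L = Σᵢ γ^(N−i) ‖f_gt − fᵢ‖₁ with γ = 0.8, and evaluates with the end-point error (EPE), the mean Euclidean distance between predicted and ground-truth flow vectors. The NumPy sketch below follows that formulation. It is an illustration, not the code of `tf_raft.losses`, whose reduction and masking of invalid pixels may differ; the `_sketch` function names are hypothetical.

```python
import numpy as np


def sequence_loss_sketch(flow_preds, flow_gt, gamma=0.8):
    """Exponentially weighted L1 loss over the N recurrent flow predictions.

    flow_preds: list of (B, H, W, 2) arrays, ordered from first to last update.
    flow_gt:    (B, H, W, 2) ground-truth flow.
    The i-th prediction (0-based) is weighted by gamma ** (N - 1 - i),
    so the final prediction gets weight 1.
    """
    n = len(flow_preds)
    loss = 0.0
    for i, pred in enumerate(flow_preds):
        weight = gamma ** (n - 1 - i)
        loss += weight * float(np.mean(np.abs(pred - flow_gt)))
    return loss


def end_point_error_sketch(flow_pred, flow_gt):
    """Mean Euclidean distance between predicted and ground-truth flow vectors."""
    return float(np.mean(np.linalg.norm(flow_pred - flow_gt, axis=-1)))
```

Weighting later updates more heavily is what lets `iters_pred` exceed `iters` at inference: every update step is trained to refine the previous estimate.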
Train via YAML configuration
train_chairs.py and train_sintel.py train a RAFT model from a YAML configuration file. Sample configs are in the configs directory. Run:
$ python train_chairs.py /path/to/config.yml
Pre-trained models
I have made the pre-trained weights (trained on FlyingChairs and on MPI-Sintel) publicly available.
You can download them via gsutil or curl.
Trained weights on FlyingChairs
$ gsutil cp -r gs://tf-raft-pretrained/2020-09-26T18-38/checkpoints .
or
$ mkdir checkpoints
$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T18-38/checkpoints/model.data-00000-of-00001
$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T18-38/checkpoints/model.index
$ mv model* checkpoints/
Trained weights on MPI-Sintel (Clean pass)
$ gsutil cp -r gs://tf-raft-pretrained/2020-09-26T08-51/checkpoints .
or
$ mkdir checkpoints
$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T08-51/checkpoints/model.data-00000-of-00001
$ curl -OL https://storage.googleapis.com/tf-raft-pretrained/2020-09-26T08-51/checkpoints/model.index
$ mv model* checkpoints/
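If neither gsutil nor curl is available, the same two checkpoint files can be fetched with Python's standard library. This sketch assumes only the URL pattern and run IDs shown in the commands above; `checkpoint_urls`, `download_checkpoint`, and `RUNS` are hypothetical names.

```python
import os
import urllib.request

BASE_URL = "https://storage.googleapis.com/tf-raft-pretrained"
# Run IDs taken from the download commands above.
RUNS = {"chairs": "2020-09-26T18-38", "sintel": "2020-09-26T08-51"}
FILES = ("model.data-00000-of-00001", "model.index")


def checkpoint_urls(run_id):
    """Return the URLs of the two TensorFlow checkpoint files for a run."""
    return [f"{BASE_URL}/{run_id}/checkpoints/{name}" for name in FILES]


def download_checkpoint(run_id, out_dir="checkpoints"):
    """Download both checkpoint files into out_dir (created if missing)."""
    os.makedirs(out_dir, exist_ok=True)
    for url in checkpoint_urls(run_id):
        target = os.path.join(out_dir, url.rsplit("/", 1)[-1])
        urllib.request.urlretrieve(url, target)
    return out_dir
```

For example, `download_checkpoint(RUNS["sintel"])` places the files so that `raft.load_weights('checkpoints/model')` finds them.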
Load weights
import numpy as np
from tf_raft.model import RAFT

# iters/iters_pred as in the Usage section
raft = RAFT(iters=iters, iters_pred=iters_pred)
raft.load_weights('checkpoints/model')

# forward (with dummy inputs)
x1 = np.random.uniform(0, 255, (1, 448, 512, 3)).astype(np.float32)
x2 = np.random.uniform(0, 255, (1, 448, 512, 3)).astype(np.float32)
flow_predictions = raft([x1, x2], training=False)
print(flow_predictions[-1].shape)  # >> (1, 448, 512, 2)
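The dummy inputs use 448 × 512, both multiples of 8. RAFT builds its features at 1/8 of the input resolution and upsamples the flow by 8, so the original PyTorch code pads arbitrary-size frames (e.g. Sintel's 436 × 1024) by edge replication before inference. Assuming tf-raft benefits from the same treatment, a minimal NumPy sketch of padding and cropping back could look like this; `pad_to_multiple` and `unpad` are hypothetical helpers.

```python
import numpy as np


def pad_to_multiple(image, multiple=8):
    """Edge-pad a (B, H, W, C) array so H and W become multiples of `multiple`.

    Returns the padded array and the (top, bottom, left, right) pad amounts.
    """
    _, h, w, _ = image.shape
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    top, left = pad_h // 2, pad_w // 2
    pads = (top, pad_h - top, left, pad_w - left)
    padded = np.pad(
        image,
        ((0, 0), (pads[0], pads[1]), (pads[2], pads[3]), (0, 0)),
        mode="edge",  # replicate border pixels, as the original InputPadder does
    )
    return padded, pads


def unpad(flow, pads):
    """Crop a padded (B, H, W, 2) flow back to the original spatial size."""
    top, bottom, left, right = pads
    h, w = flow.shape[1], flow.shape[2]
    return flow[:, top:h - bottom, left:w - right]
```

Pad both frames with the same call, run the model on the padded pair, then apply `unpad` to `np.asarray(flow_predictions[-1])` with the pads returned for either frame.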
Note
Though I have tried to reproduce the original implementation faithfully, there are some differences between the original and mine (mainly due to the framework used: PyTorch vs. TensorFlow):
- The original implementation provides a CUDA-based correlation function, but mine does not. My TF-based implementation works well, but the CUDA-based one may run faster.
- I trained my models on FlyingChairs and MPI-Sintel separately in my private environment (GCP with a P100 accelerator). They trained well but did not reach the best scores reported in the paper, which were obtained by training on multiple datasets.
- The original implementation uses mixed precision, which may make training much faster; mine does not. TensorFlow also supports mixed precision with a few additional lines; see https://www.tensorflow.org/guide/mixed_precision if interested.
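The correlation function referred to above is RAFT's all-pairs correlation volume: for two feature maps of shape (H, W, D), entry [i, j, k, l] is the dot product of the feature vector at (i, j) in the first map with the one at (k, l) in the second, which the original code scales by 1/√D. Framework-level implementations can compute it as a single matrix product, as in this NumPy sketch; it omits the correlation pyramid and local lookup that RAFT builds on top, and `all_pairs_correlation` is a hypothetical name.

```python
import numpy as np


def all_pairs_correlation(fmap1, fmap2):
    """Dense all-pairs correlation between two (B, H, W, D) feature maps.

    Returns a (B, H, W, H, W) volume where entry [b, i, j, k, l] is the dot
    product of fmap1[b, i, j] and fmap2[b, k, l], scaled by 1/sqrt(D).
    """
    b, h, w, d = fmap1.shape
    f1 = fmap1.reshape(b, h * w, d)
    f2 = fmap2.reshape(b, h * w, d)
    corr = np.matmul(f1, f2.transpose(0, 2, 1))  # (B, H*W, H*W)
    return corr.reshape(b, h, w, h, w) / np.sqrt(d)
```

The volume grows as (H·W)², which is why it is computed on 1/8-resolution features; a custom CUDA kernel mainly saves time and memory here.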
Additionally, global gradient clipping seems to be essential for stable training, though it is not emphasized in the original paper. It can be done with torch.nn.utils.clip_grad_norm_(model.parameters(), clip) in PyTorch and tf.clip_by_global_norm(grads, clip_norm) in TensorFlow (implemented in self.train_step in tf_raft/model.py).
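To make the clipping rule concrete: tf.clip_by_global_norm computes one norm over all gradients together, global_norm = sqrt(Σ‖g‖²), and multiplies every gradient by clip_norm / max(global_norm, clip_norm). The NumPy re-statement below is only an illustration of that documented rule; in TensorFlow itself the call is `clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm)`, and `clip_by_global_norm_np` is a hypothetical name.

```python
import numpy as np


def clip_by_global_norm_np(grads, clip_norm):
    """NumPy re-statement of the rule used by tf.clip_by_global_norm.

    global_norm is the L2 norm of all gradients concatenated; each gradient
    is scaled by clip_norm / max(global_norm, clip_norm), so nothing changes
    when global_norm <= clip_norm and the relative direction is preserved.
    """
    global_norm = float(np.sqrt(sum(float(np.sum(np.square(g))) for g in grads)))
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads], global_norm
```

Because all gradients share one scale factor, the update direction is preserved; only its length is capped, unlike per-tensor clipping.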
References
File details
Details for the file tf-raft-0.1.4.tar.gz.
File metadata
- Download URL: tf-raft-0.1.4.tar.gz
- Upload date:
- Size: 19.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.1.0 CPython/3.8.5 Linux/5.4.0-1025-azure
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | dea7ff166438f29bde599106c007a549d36ce5b3faa8b32c6f4f5f697ae20af8 |
| MD5 | a97c72380b0777e7d225e9af397923a9 |
| BLAKE2b-256 | 02f1e031d0bc7b7dadf115c8e6721fe0468e1bddf8375f49a8993108abcdb955 |
File details
Details for the file tf_raft-0.1.4-py3-none-any.whl.
File metadata
- Download URL: tf_raft-0.1.4-py3-none-any.whl
- Upload date:
- Size: 21.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.1.0 CPython/3.8.5 Linux/5.4.0-1025-azure
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5aa818ec8a8ffecd742a66fbc3f18eee3d619a1e24e3c534fa50b58184ddacda |
| MD5 | 96c6535cf90a5f96849753fb3a0c20bb |
| BLAKE2b-256 | 19933d137a37310693f56068d75e0fbdb30ea1ca058965a4f604d1fba9d19f07 |