diff options
Diffstat (limited to 'Codes/flownet2/src/flownet_css')
| -rw-r--r-- | Codes/flownet2/src/flownet_css/__init__.py | 0 | ||||
| -rw-r--r-- | Codes/flownet2/src/flownet_css/flownet_css.py | 41 | ||||
| -rw-r--r-- | Codes/flownet2/src/flownet_css/test.py | 51 | ||||
| -rw-r--r-- | Codes/flownet2/src/flownet_css/train.py | 22 |
4 files changed, 114 insertions, 0 deletions
diff --git a/Codes/flownet2/src/flownet_css/__init__.py b/Codes/flownet2/src/flownet_css/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/Codes/flownet2/src/flownet_css/__init__.py diff --git a/Codes/flownet2/src/flownet_css/flownet_css.py b/Codes/flownet2/src/flownet_css/flownet_css.py new file mode 100644 index 0000000..93d9db2 --- /dev/null +++ b/Codes/flownet2/src/flownet_css/flownet_css.py @@ -0,0 +1,41 @@ +from ..net import Net, Mode +from ..flownet_cs.flownet_cs import FlowNetCS +from ..flownet_s.flownet_s import FlowNetS +from ..flow_warp import flow_warp +import tensorflow as tf + + +class FlowNetCSS(Net): + + def __init__(self, mode=Mode.TRAIN, debug=False): + self.net_cs = FlowNetCS(mode, debug) + self.net_s = FlowNetS(mode, debug) + super(FlowNetCSS, self).__init__(mode=mode, debug=debug) + + def model(self, inputs, training_schedule, trainable=True): + with tf.variable_scope('FlowNetCSS'): + # Forward pass through FlowNetCS with weights frozen + net_cs_predictions = self.net_cs.model(inputs, training_schedule, trainable=True) + + # Perform flow warping (to move image B closer to image A based on flow prediction) + warped = flow_warp(inputs['input_b'], net_cs_predictions['flow']) + + # Compute brightness error: sqrt(sum (input_a - warped)^2 over channels) + brightness_error = inputs['input_a'] - warped + brightness_error = tf.square(brightness_error) + brightness_error = tf.reduce_sum(brightness_error, keep_dims=True, axis=3) + brightness_error = tf.sqrt(brightness_error) + + # Gather all inputs to FlowNetS + inputs_to_s = { + 'input_a': inputs['input_a'], + 'input_b': inputs['input_b'], + 'warped': warped, + 'flow': net_cs_predictions['flow'] * 0.05, + 'brightness_error': brightness_error, + } + + return self.net_s.model(inputs_to_s, training_schedule, trainable=trainable) + + def loss(self, flow, predictions): + return self.net_s.loss(flow, predictions) diff --git a/Codes/flownet2/src/flownet_css/test.py b/Codes/flownet2/src/flownet_css/test.py new file mode 100644 index 0000000..9d1249e --- /dev/null +++ b/Codes/flownet2/src/flownet_css/test.py @@ -0,0 +1,51 @@ +import argparse +import os +from ..net import Mode +from .flownet_css import FlowNetCSS + +FLAGS = None + + +def main(): + # Create a new network + net = FlowNetCSS(mode=Mode.TEST) + + # Train on the data + net.test( + checkpoint='./checkpoints/FlowNetCSS/flownet-CSS.ckpt-0', + input_a_path=FLAGS.input_a, + input_b_path=FLAGS.input_b, + out_path=FLAGS.out, + ) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--input_a', + type=str, + required=True, + help='Path to first image' + ) + parser.add_argument( + '--input_b', + type=str, + required=True, + help='Path to second image' + ) + parser.add_argument( + '--out', + type=str, + required=True, + help='Path to output flow result' + ) + FLAGS = parser.parse_args() + + # Verify arguments are valid + if not os.path.exists(FLAGS.input_a): + raise ValueError('image_a path must exist') + if not os.path.exists(FLAGS.input_b): + raise ValueError('image_b path must exist') + if not os.path.isdir(FLAGS.out): + raise ValueError('out directory must exist') + main() diff --git a/Codes/flownet2/src/flownet_css/train.py b/Codes/flownet2/src/flownet_css/train.py new file mode 100644 index 0000000..2964f3e --- /dev/null +++ b/Codes/flownet2/src/flownet_css/train.py @@ -0,0 +1,22 @@ +from ..dataloader import load_batch +from ..dataset_configs import FLYING_CHAIRS_DATASET_CONFIG +from ..training_schedules import LONG_SCHEDULE +from .flownet_css import FlowNetCSS + +# Create a new network +net = FlowNetCSS() + +# Load a batch of data +input_a, input_b, flow = load_batch(FLYING_CHAIRS_DATASET_CONFIG, 'sample', net.global_step) + +# Train on the data +net.train( + log_dir='./logs/flownet_css', + training_schedule=LONG_SCHEDULE, + input_a=input_a, + input_b=input_b, + flow=flow, + # Load trained weights for CS part of network + checkpoints={ + './checkpoints/FlowNetCS/flownet-CS.ckpt-0': ('FlowNetCSS/FlowNetCS', 'FlowNetCSS')} +) |
