diff options
| author | StevenLiuWen <liuwen@shanghaitech.edu.cn> | 2018-03-13 03:28:06 -0400 |
|---|---|---|
| committer | StevenLiuWen <liuwen@shanghaitech.edu.cn> | 2018-03-13 03:28:06 -0400 |
| commit | fede6ca1dd0077ff509d84bd24028cc7a93bb119 (patch) | |
| tree | af7f6e759b5dec4fc2964daed09e903958b919ed /Codes/flownet2/src/flownet_cs/flownet_cs.py | |
first commit
Diffstat (limited to 'Codes/flownet2/src/flownet_cs/flownet_cs.py')
| -rw-r--r-- | Codes/flownet2/src/flownet_cs/flownet_cs.py | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/Codes/flownet2/src/flownet_cs/flownet_cs.py b/Codes/flownet2/src/flownet_cs/flownet_cs.py new file mode 100644 index 0000000..aeaea47 --- /dev/null +++ b/Codes/flownet2/src/flownet_cs/flownet_cs.py @@ -0,0 +1,41 @@ +from ..net import Net, Mode +from ..flownet_c.flownet_c import FlowNetC +from ..flownet_s.flownet_s import FlowNetS +from ..flow_warp import flow_warp +import tensorflow as tf + + +class FlowNetCS(Net): + + def __init__(self, mode=Mode.TRAIN, debug=False): + self.net_c = FlowNetC(mode, debug) + self.net_s = FlowNetS(mode, debug) + super(FlowNetCS, self).__init__(mode=mode, debug=debug) + + def model(self, inputs, training_schedule, trainable=True): + with tf.variable_scope('FlowNetCS'): + # Forward pass through FlowNetC with weights frozen + net_c_predictions = self.net_c.model(inputs, training_schedule, trainable=True) + + # Perform flow warping (to move image B closer to image A based on flow prediction) + warped = flow_warp(inputs['input_b'], net_c_predictions['flow']) + + # Compute brightness error: sqrt(sum (input_a - warped)^2 over channels) + brightness_error = inputs['input_a'] - warped + brightness_error = tf.square(brightness_error) + brightness_error = tf.reduce_sum(brightness_error, keep_dims=True, axis=3) + brightness_error = tf.sqrt(brightness_error) + + # Gather all inputs to FlowNetS + inputs_to_s = { + 'input_a': inputs['input_a'], + 'input_b': inputs['input_b'], + 'warped': warped, + 'flow': net_c_predictions['flow'] * 0.05, + 'brightness_error': brightness_error, + } + + return self.net_s.model(inputs_to_s, training_schedule, trainable=trainable) + + def loss(self, flow, predictions): + return self.net_s.loss(flow, predictions) |
