1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
|
from ..net import Net, Mode
from ..flownet_c.flownet_c import FlowNetC
from ..flownet_s.flownet_s import FlowNetS
from ..flow_warp import flow_warp
import tensorflow as tf
class FlowNetCS(Net):
    """Stacked network that feeds FlowNetC's flow estimate into FlowNetS.

    FlowNetC is run first with its weights frozen; its predicted flow is used
    to warp ``input_b``, and the warped image, the flow and a per-pixel
    brightness error are handed to FlowNetS together with the original images.
    """

    def __init__(self, mode=Mode.TRAIN, debug=False):
        """Create the two sub-networks and initialise the base ``Net``.

        Args:
            mode: Forwarded unchanged to FlowNetC, FlowNetS and ``Net``.
            debug: Forwarded unchanged to FlowNetC, FlowNetS and ``Net``.
        """
        # The sub-networks are created before the base-class initializer runs;
        # the order is kept in case Net.__init__ relies on them being present.
        self.net_c = FlowNetC(mode, debug)
        self.net_s = FlowNetS(mode, debug)
        super(FlowNetCS, self).__init__(mode=mode, debug=debug)

    def model(self, inputs, training_schedule, trainable=True):
        """Build the FlowNetC -> warp -> FlowNetS graph.

        Args:
            inputs: Mapping that provides at least ``'input_a'`` and
                ``'input_b'`` (assumed to be image tensors with channels on
                axis 3 -- TODO confirm against the data pipeline).
            training_schedule: Passed through to both sub-networks.
            trainable: Whether FlowNetS's variables are trainable. FlowNetC is
                always built with ``trainable=False`` regardless of this flag.

        Returns:
            Whatever ``FlowNetS.model`` returns for the assembled inputs.
        """
        with tf.variable_scope('FlowNetCS'):
            # Forward pass through FlowNetC with weights frozen: only the
            # second network in the stack is meant to be trained here, so the
            # caller's `trainable` flag is deliberately NOT forwarded.
            net_c_predictions = self.net_c.model(inputs, training_schedule, trainable=False)
            flow_c = net_c_predictions['flow']

            # Perform flow warping (to move image B closer to image A based on flow prediction)
            warped = flow_warp(inputs['input_b'], flow_c)

            # Compute brightness error: sqrt(sum (input_a - warped)^2 over channels).
            # keep_dims=True keeps the reduced axis so the result stays 4-D.
            brightness_error = inputs['input_a'] - warped
            brightness_error = tf.square(brightness_error)
            brightness_error = tf.reduce_sum(brightness_error, keep_dims=True, axis=3)
            brightness_error = tf.sqrt(brightness_error)

            # Gather all inputs to FlowNetS
            inputs_to_s = {
                'input_a': inputs['input_a'],
                'input_b': inputs['input_b'],
                'warped': warped,
                # NOTE(review): 0.05 rescales FlowNetC's flow before it is fed
                # to FlowNetS -- presumably to match the flow scaling FlowNetS
                # expects; confirm against the training setup.
                'flow': flow_c * 0.05,
                'brightness_error': brightness_error,
            }

            # Built inside the variable scope so FlowNetS's variables are
            # created under 'FlowNetCS'.
            return self.net_s.model(inputs_to_s, training_schedule, trainable=trainable)

    def loss(self, flow, predictions):
        """Return FlowNetS's loss for ``flow`` and ``predictions``.

        The stacked network's output comes from FlowNetS, so the loss is
        delegated to it unchanged.
        """
        return self.net_s.loss(flow, predictions)
|