diff options
Diffstat (limited to 'Codes/flownet2/src/flownet_cs')
| -rw-r--r-- | Codes/flownet2/src/flownet_cs/__init__.py | 0 | ||||
| -rw-r--r-- | Codes/flownet2/src/flownet_cs/flownet_cs.py | 41 | ||||
| -rw-r--r-- | Codes/flownet2/src/flownet_cs/test.py | 51 | ||||
| -rw-r--r-- | Codes/flownet2/src/flownet_cs/train.py | 21 |
4 files changed, 113 insertions, 0 deletions
diff --git a/Codes/flownet2/src/flownet_cs/__init__.py b/Codes/flownet2/src/flownet_cs/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/Codes/flownet2/src/flownet_cs/__init__.py diff --git a/Codes/flownet2/src/flownet_cs/flownet_cs.py b/Codes/flownet2/src/flownet_cs/flownet_cs.py new file mode 100644 index 0000000..aeaea47 --- /dev/null +++ b/Codes/flownet2/src/flownet_cs/flownet_cs.py @@ -0,0 +1,41 @@ +from ..net import Net, Mode +from ..flownet_c.flownet_c import FlowNetC +from ..flownet_s.flownet_s import FlowNetS +from ..flow_warp import flow_warp +import tensorflow as tf + + +class FlowNetCS(Net): + + def __init__(self, mode=Mode.TRAIN, debug=False): + self.net_c = FlowNetC(mode, debug) + self.net_s = FlowNetS(mode, debug) + super(FlowNetCS, self).__init__(mode=mode, debug=debug) + + def model(self, inputs, training_schedule, trainable=True): + with tf.variable_scope('FlowNetCS'): + # Forward pass through FlowNetC with weights frozen + net_c_predictions = self.net_c.model(inputs, training_schedule, trainable=True) + + # Perform flow warping (to move image B closer to image A based on flow prediction) + warped = flow_warp(inputs['input_b'], net_c_predictions['flow']) + + # Compute brightness error: sqrt(sum (input_a - warped)^2 over channels) + brightness_error = inputs['input_a'] - warped + brightness_error = tf.square(brightness_error) + brightness_error = tf.reduce_sum(brightness_error, keep_dims=True, axis=3) + brightness_error = tf.sqrt(brightness_error) + + # Gather all inputs to FlowNetS + inputs_to_s = { + 'input_a': inputs['input_a'], + 'input_b': inputs['input_b'], + 'warped': warped, + 'flow': net_c_predictions['flow'] * 0.05, + 'brightness_error': brightness_error, + } + + return self.net_s.model(inputs_to_s, training_schedule, trainable=trainable) + + def loss(self, flow, predictions): + return self.net_s.loss(flow, predictions) diff --git a/Codes/flownet2/src/flownet_cs/test.py b/Codes/flownet2/src/flownet_cs/test.py new file mode 100644 index 0000000..ae00ff4 --- /dev/null +++ b/Codes/flownet2/src/flownet_cs/test.py @@ -0,0 +1,51 @@ +import argparse +import os +from ..net import Mode +from .flownet_cs import FlowNetCS + +FLAGS = None + + +def main(): + # Create a new network + net = FlowNetCS(mode=Mode.TEST) + + # Train on the data + net.test( + checkpoint='./checkpoints/FlowNetCS/flownet-CS.ckpt-0', + input_a_path=FLAGS.input_a, + input_b_path=FLAGS.input_b, + out_path=FLAGS.out, + ) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--input_a', + type=str, + required=True, + help='Path to first image' + ) + parser.add_argument( + '--input_b', + type=str, + required=True, + help='Path to second image' + ) + parser.add_argument( + '--out', + type=str, + required=True, + help='Path to output flow result' + ) + FLAGS = parser.parse_args() + + # Verify arguments are valid + if not os.path.exists(FLAGS.input_a): + raise ValueError('image_a path must exist') + if not os.path.exists(FLAGS.input_b): + raise ValueError('image_b path must exist') + if not os.path.isdir(FLAGS.out): + raise ValueError('out directory must exist') + main() diff --git a/Codes/flownet2/src/flownet_cs/train.py b/Codes/flownet2/src/flownet_cs/train.py new file mode 100644 index 0000000..9376132 --- /dev/null +++ b/Codes/flownet2/src/flownet_cs/train.py @@ -0,0 +1,21 @@ +from ..dataloader import load_batch +from ..dataset_configs import FLYING_CHAIRS_DATASET_CONFIG +from ..training_schedules import LONG_SCHEDULE +from .flownet_cs import FlowNetCS + +# Create a new network +net = FlowNetCS() + +# Load a batch of data +input_a, input_b, flow = load_batch(FLYING_CHAIRS_DATASET_CONFIG, 'sample', net.global_step) + +# Train on the data +net.train( + log_dir='./logs/flownet_cs', + training_schedule=LONG_SCHEDULE, + input_a=input_a, + input_b=input_b, + flow=flow, + # Load trained weights for C part of network + checkpoints={'./checkpoints/FlowNetC/flownet-C.ckpt-0': ('FlowNetCS/FlowNetC', 'FlowNetCS')} +) |
