summaryrefslogtreecommitdiff
path: root/Codes/flownet2/src/flownet_cs
diff options
context:
space:
mode:
Diffstat (limited to 'Codes/flownet2/src/flownet_cs')
-rw-r--r--Codes/flownet2/src/flownet_cs/__init__.py0
-rw-r--r--Codes/flownet2/src/flownet_cs/flownet_cs.py41
-rw-r--r--Codes/flownet2/src/flownet_cs/test.py51
-rw-r--r--Codes/flownet2/src/flownet_cs/train.py21
4 files changed, 113 insertions, 0 deletions
diff --git a/Codes/flownet2/src/flownet_cs/__init__.py b/Codes/flownet2/src/flownet_cs/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/Codes/flownet2/src/flownet_cs/__init__.py
diff --git a/Codes/flownet2/src/flownet_cs/flownet_cs.py b/Codes/flownet2/src/flownet_cs/flownet_cs.py
new file mode 100644
index 0000000..aeaea47
--- /dev/null
+++ b/Codes/flownet2/src/flownet_cs/flownet_cs.py
@@ -0,0 +1,41 @@
+from ..net import Net, Mode
+from ..flownet_c.flownet_c import FlowNetC
+from ..flownet_s.flownet_s import FlowNetS
+from ..flow_warp import flow_warp
+import tensorflow as tf
+
+
+class FlowNetCS(Net):
+
+ def __init__(self, mode=Mode.TRAIN, debug=False):
+ self.net_c = FlowNetC(mode, debug)
+ self.net_s = FlowNetS(mode, debug)
+ super(FlowNetCS, self).__init__(mode=mode, debug=debug)
+
+ def model(self, inputs, training_schedule, trainable=True):
+ with tf.variable_scope('FlowNetCS'):
+ # Forward pass through FlowNetC with weights frozen
+ net_c_predictions = self.net_c.model(inputs, training_schedule, trainable=True)
+
+ # Perform flow warping (to move image B closer to image A based on flow prediction)
+ warped = flow_warp(inputs['input_b'], net_c_predictions['flow'])
+
+ # Compute brightness error: sqrt(sum (input_a - warped)^2 over channels)
+ brightness_error = inputs['input_a'] - warped
+ brightness_error = tf.square(brightness_error)
+ brightness_error = tf.reduce_sum(brightness_error, keep_dims=True, axis=3)
+ brightness_error = tf.sqrt(brightness_error)
+
+ # Gather all inputs to FlowNetS
+ inputs_to_s = {
+ 'input_a': inputs['input_a'],
+ 'input_b': inputs['input_b'],
+ 'warped': warped,
+ 'flow': net_c_predictions['flow'] * 0.05,
+ 'brightness_error': brightness_error,
+ }
+
+ return self.net_s.model(inputs_to_s, training_schedule, trainable=trainable)
+
+ def loss(self, flow, predictions):
+ return self.net_s.loss(flow, predictions)
diff --git a/Codes/flownet2/src/flownet_cs/test.py b/Codes/flownet2/src/flownet_cs/test.py
new file mode 100644
index 0000000..ae00ff4
--- /dev/null
+++ b/Codes/flownet2/src/flownet_cs/test.py
@@ -0,0 +1,51 @@
+import argparse
+import os
+from ..net import Mode
+from .flownet_cs import FlowNetCS
+
+FLAGS = None
+
+
+def main():
+ # Create a new network
+ net = FlowNetCS(mode=Mode.TEST)
+
+ # Train on the data
+ net.test(
+ checkpoint='./checkpoints/FlowNetCS/flownet-CS.ckpt-0',
+ input_a_path=FLAGS.input_a,
+ input_b_path=FLAGS.input_b,
+ out_path=FLAGS.out,
+ )
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--input_a',
+ type=str,
+ required=True,
+ help='Path to first image'
+ )
+ parser.add_argument(
+ '--input_b',
+ type=str,
+ required=True,
+ help='Path to second image'
+ )
+ parser.add_argument(
+ '--out',
+ type=str,
+ required=True,
+ help='Path to output flow result'
+ )
+ FLAGS = parser.parse_args()
+
+ # Verify arguments are valid
+ if not os.path.exists(FLAGS.input_a):
+ raise ValueError('image_a path must exist')
+ if not os.path.exists(FLAGS.input_b):
+ raise ValueError('image_b path must exist')
+ if not os.path.isdir(FLAGS.out):
+ raise ValueError('out directory must exist')
+ main()
diff --git a/Codes/flownet2/src/flownet_cs/train.py b/Codes/flownet2/src/flownet_cs/train.py
new file mode 100644
index 0000000..9376132
--- /dev/null
+++ b/Codes/flownet2/src/flownet_cs/train.py
@@ -0,0 +1,21 @@
+from ..dataloader import load_batch
+from ..dataset_configs import FLYING_CHAIRS_DATASET_CONFIG
+from ..training_schedules import LONG_SCHEDULE
+from .flownet_cs import FlowNetCS
+
+# Create a new network
+net = FlowNetCS()
+
+# Load a batch of data
+input_a, input_b, flow = load_batch(FLYING_CHAIRS_DATASET_CONFIG, 'sample', net.global_step)
+
+# Train on the data
+net.train(
+ log_dir='./logs/flownet_cs',
+ training_schedule=LONG_SCHEDULE,
+ input_a=input_a,
+ input_b=input_b,
+ flow=flow,
+ # Load trained weights for C part of network
+ checkpoints={'./checkpoints/FlowNetC/flownet-C.ckpt-0': ('FlowNetCS/FlowNetC', 'FlowNetCS')}
+)