summaryrefslogtreecommitdiff
path: root/Codes/flownet2/test.py
diff options
context:
space:
mode:
Diffstat (limited to 'Codes/flownet2/test.py')
-rw-r--r--Codes/flownet2/test.py163
1 files changed, 163 insertions, 0 deletions
diff --git a/Codes/flownet2/test.py b/Codes/flownet2/test.py
new file mode 100644
index 0000000..2fcb380
--- /dev/null
+++ b/Codes/flownet2/test.py
@@ -0,0 +1,163 @@
+import os
+import tensorflow as tf
+import numpy as np
+from scipy.misc import imread
+import matplotlib
+from src.flowlib import read_flow, flow_to_image
+matplotlib.use('TKAgg')
+import matplotlib.pyplot as plt
+
+_preprocessing_ops = tf.load_op_library(
+ tf.resource_loader.get_path_to_datafile("./src/ops/build/preprocessing.so"))
+
+
+def display(img, c):
+ plt.subplot(int('22' + str(c + 1)))
+ plt.imshow(img[0, :, :, :])
+
+
+def main():
+ """
+.Input("image_a: float32")
+.Input("image_b: float32")
+.Attr("crop: list(int) >= 2")
+.Attr("params_a_name: list(string)")
+.Attr("params_a_rand_type: list(string)")
+.Attr("params_a_exp: list(bool)")
+.Attr("params_a_mean: list(float32)")
+.Attr("params_a_spread: list(float32)")
+.Attr("params_a_prob: list(float32)")
+.Attr("params_b_name: list(string)")
+.Attr("params_b_rand_type: list(string)")
+.Attr("params_b_exp: list(bool)")
+.Attr("params_b_mean: list(float32)")
+.Attr("params_b_spread: list(float32)")
+.Attr("params_b_prob: list(float32)")
+.Output("aug_image_a: float32")
+.Output("aug_image_b: float32")
+.Output("spatial_transform_a: float32")
+.Output("inv_spatial_transform_b: float32")
+ """
+
+ crop = [364, 492]
+ params_a_name = ['translate_x', 'translate_y']
+ params_a_rand_type = ['uniform_bernoulli', 'uniform_bernoulli']
+ params_a_exp = [False, False]
+ params_a_mean = [0.0, 0.0]
+ params_a_spread = [0.4, 0.4]
+ params_a_prob = [1.0, 1.0]
+ params_b_name = []
+ params_b_rand_type = []
+ params_b_exp = []
+ params_b_mean = []
+ params_b_spread = []
+ params_b_prob = []
+
+ with tf.Session() as sess:
+ with tf.device('/gpu:0'):
+ image_a = imread('./img0.ppm') / 255.0
+ image_b = imread('./img1.ppm') / 255.0
+ flow = read_flow('./flow.flo')
+
+ image_a_tf = tf.expand_dims(tf.to_float(tf.constant(image_a, dtype=tf.float64)), 0)
+ image_b_tf = tf.expand_dims(tf.to_float(tf.constant(image_b, dtype=tf.float64)), 0)
+
+ preprocess = _preprocessing_ops.data_augmentation(image_a_tf,
+ image_b_tf,
+ crop,
+ params_a_name,
+ params_a_rand_type,
+ params_a_exp,
+ params_a_mean,
+ params_a_spread,
+ params_a_prob,
+ params_b_name,
+ params_b_rand_type,
+ params_b_exp,
+ params_b_mean,
+ params_b_spread,
+ params_b_prob)
+
+ out = sess.run(preprocess)
+ trans = out.spatial_transform_a
+ inv_trans = out.inv_spatial_transform_b
+
+ print(trans.shape)
+ print(inv_trans.shape)
+
+ flow_tf = tf.expand_dims(tf.to_float(tf.constant(flow)), 0)
+ aug_flow_tf = _preprocessing_ops.flow_augmentation(flow_tf, trans, inv_trans, crop)
+
+ aug_flow = sess.run(aug_flow_tf)[0, :, :, :]
+
+ # Plot img0, img0aug
+ plt.subplot(321)
+ plt.imshow(image_a)
+ plt.subplot(322)
+ plt.imshow(out.aug_image_a[0, :, :, :])
+
+ # Plot img1, img1aug
+ plt.subplot(323)
+ plt.imshow(image_b)
+ plt.subplot(324)
+ plt.imshow(out.aug_image_b[0, :, :, :])
+
+ # Plot flow, flowaug
+ plt.subplot(325)
+ plt.imshow(flow_to_image(flow))
+ plt.subplot(326)
+ plt.imshow(flow_to_image(aug_flow))
+
+ plt.show()
+
+ # image_b_aug = sess.run(image_b_tf)
+ #
+ # display(np.expand_dims(image_a, 0), 0)
+ # display(np.expand_dims(image_b, 0), 1)
+ # display(image_a_aug, 2)
+ # display(image_b_aug, 3)
+ # plt.show()
+
+ # o = _preprocessing_ops.flow_augmentation(flow, trans, inv_t, [4, 8])
+ # print n[:, :, :]
+ # print n[0, 0, 1], n[0, 0, 0]
+ # print n[1, 0, 1], n[1, 0, 0]
+ # print n[2, 0, 1], n[2, 0, 0]
+ # print '---'
+ # print sess.run(o)
+
+ """# Goes along width first!!
+ // Caffe, NKHW: ((n * K + k) * H + h) * W + w at point (n, k, h, w)
+ // TF, NHWK: ((n * H + h) * W + w) * K + k at point (n, h, w, k)
+
+ H=5, W=10, K=2
+ n=0, h=1, w=5, k=0
+
+ (2 * 10) + c
+
+ 30 49 n[0, 1, 5, 0]"""
+
+
+print(os.getpid())
+input("Press Enter to continue...")
+main()
+
+# Last index is channel!!
+
+# K
+
+# value 13 should be at [0, 2, 7, 1] aka batch=0, height=1, width=0, channel=0. it is at index=20.
+#
+# items = {
+# 'N': [0, 0],
+# 'H': [5, 2],
+# 'W': [10, 7],
+# 'K': [2, 1],
+# }
+#
+# for (i1, v1) in items.iteritems():
+# for (i2, v2) in items.iteritems():
+# for (i3, v3) in items.iteritems():
+# for (i4, v4) in items.iteritems():
+# if ((v1[1] * v2[0] + v2[1]) * v3[0] + v3[1]) * v4[0] + v4[1] == 55:
+# print 'found it: ', i1, i2, i3, i4