summaryrefslogtreecommitdiff
path: root/Codes/flownet2/src/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'Codes/flownet2/src/utils.py')
-rw-r--r--Codes/flownet2/src/utils.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/Codes/flownet2/src/utils.py b/Codes/flownet2/src/utils.py
new file mode 100644
index 0000000..f6abe18
--- /dev/null
+++ b/Codes/flownet2/src/utils.py
@@ -0,0 +1,46 @@
+import tensorflow as tf
+
+
+# Thanks, https://github.com/tensorflow/tensorflow/issues/4079
+def LeakyReLU(x, leak=0.1, name="lrelu"):
+ with tf.variable_scope(name):
+ f1 = 0.5 * (1.0 + leak)
+ f2 = 0.5 * (1.0 - leak)
+ return f1 * x + f2 * abs(x)
+
+
+def average_endpoint_error(labels, predictions):
+ """
+ Given labels and predictions of size (N, H, W, 2), calculates average endpoint error:
+ sqrt[sum_across_channels{(X - Y)^2}]
+ """
+ num_samples = predictions.shape.as_list()[0]
+ with tf.name_scope(None, "average_endpoint_error", (predictions, labels)) as scope:
+ predictions = tf.to_float(predictions)
+ labels = tf.to_float(labels)
+ predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+
+ squared_difference = tf.square(tf.subtract(predictions, labels))
+ # sum across channels: sum[(X - Y)^2] -> N, H, W, 1
+ loss = tf.reduce_sum(squared_difference, 3, keep_dims=True)
+ loss = tf.sqrt(loss)
+ return tf.reduce_sum(loss) / num_samples
+
+
+def pad(tensor, num=1):
+ """
+ Pads the given tensor along the height and width dimensions with `num` 0s on each side
+ """
+ return tf.pad(tensor, [[0, 0], [num, num], [num, num], [0, 0]], "CONSTANT")
+
+
+def antipad(tensor, num=1):
+ """
+ Performs a crop. "padding" for a deconvolutional layer (conv2d tranpose) removes
+ padding from the output rather than adding it to the input.
+ """
+ batch, h, w, c = tensor.get_shape().as_list()
+ # print(batch, h, w, c)
+ # print(type(batch), type(h), type(w), type(c))
+ # return tf.slice(tensor, begin=[0, num, num, 0], size=[batch, h - 2 * num, w - 2 * num, c])
+ return tensor[:, num: num + h - 2 * num, num: num + w - 2 * num, :]