summaryrefslogtreecommitdiff
path: root/Codes/flownet2/src/correlation.py
diff options
context:
space:
mode:
Diffstat (limited to 'Codes/flownet2/src/correlation.py')
-rw-r--r--Codes/flownet2/src/correlation.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/Codes/flownet2/src/correlation.py b/Codes/flownet2/src/correlation.py
new file mode 100644
index 0000000..60a5c37
--- /dev/null
+++ b/Codes/flownet2/src/correlation.py
@@ -0,0 +1,35 @@
+import tensorflow as tf
+
+_correlation_ops = tf.load_op_library(
+ tf.resource_loader.get_path_to_datafile("./ops/build/correlation.so"))
+
+
+def correlation(input_a, input_b, kernel_size, max_displacement, stride_1, stride_2, padding):
+ return _correlation_ops.correlation(input_a,
+ input_b,
+ kernel_size,
+ max_displacement,
+ stride_1,
+ stride_2,
+ padding)
+
+
+@tf.RegisterGradient("Correlation")
+def _correlation_grad(corr_op, gradients):
+ kernel_size = corr_op.get_attr("kernel_size")
+ max_displacement = corr_op.get_attr("max_displacement")
+ stride_1 = corr_op.get_attr("stride_1")
+ stride_2 = corr_op.get_attr("stride_2")
+ pad = corr_op.get_attr("pad")
+
+ corr_grads = _correlation_ops.correlation_grad(gradients,
+ corr_op.inputs[0],
+ corr_op.inputs[1],
+ kernel_size,
+ max_displacement,
+ stride_1,
+ stride_2,
+ pad)
+
+ # Return the gradients with respect to input_a and input_b
+ return corr_grads.backprops_a, corr_grads.backprops_b