From fede6ca1dd0077ff509d84bd24028cc7a93bb119 Mon Sep 17 00:00:00 2001 From: StevenLiuWen Date: Tue, 13 Mar 2018 03:28:06 -0400 Subject: first commit --- Codes/flownet2/src/correlation.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 Codes/flownet2/src/correlation.py (limited to 'Codes/flownet2/src/correlation.py') diff --git a/Codes/flownet2/src/correlation.py b/Codes/flownet2/src/correlation.py new file mode 100644 index 0000000..60a5c37 --- /dev/null +++ b/Codes/flownet2/src/correlation.py @@ -0,0 +1,35 @@ +import tensorflow as tf + +_correlation_ops = tf.load_op_library( + tf.resource_loader.get_path_to_datafile("./ops/build/correlation.so")) + + +def correlation(input_a, input_b, kernel_size, max_displacement, stride_1, stride_2, padding): + return _correlation_ops.correlation(input_a, + input_b, + kernel_size, + max_displacement, + stride_1, + stride_2, + padding) + + +@tf.RegisterGradient("Correlation") +def _correlation_grad(corr_op, gradients): + kernel_size = corr_op.get_attr("kernel_size") + max_displacement = corr_op.get_attr("max_displacement") + stride_1 = corr_op.get_attr("stride_1") + stride_2 = corr_op.get_attr("stride_2") + pad = corr_op.get_attr("pad") + + corr_grads = _correlation_ops.correlation_grad(gradients, + corr_op.inputs[0], + corr_op.inputs[1], + kernel_size, + max_displacement, + stride_1, + stride_2, + pad) + + # Return the gradients with respect to input_a and input_b + return corr_grads.backprops_a, corr_grads.backprops_b -- cgit v1.2.3-70-g09d2