diff options
Diffstat (limited to 'Codes/flownet2/src/correlation.py')
| -rw-r--r-- | Codes/flownet2/src/correlation.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/Codes/flownet2/src/correlation.py b/Codes/flownet2/src/correlation.py new file mode 100644 index 0000000..60a5c37 --- /dev/null +++ b/Codes/flownet2/src/correlation.py @@ -0,0 +1,35 @@ +import tensorflow as tf + +_correlation_ops = tf.load_op_library( + tf.resource_loader.get_path_to_datafile("./ops/build/correlation.so")) + + +def correlation(input_a, input_b, kernel_size, max_displacement, stride_1, stride_2, padding): + return _correlation_ops.correlation(input_a, + input_b, + kernel_size, + max_displacement, + stride_1, + stride_2, + padding) + + +@tf.RegisterGradient("Correlation") +def _correlation_grad(corr_op, gradients): + kernel_size = corr_op.get_attr("kernel_size") + max_displacement = corr_op.get_attr("max_displacement") + stride_1 = corr_op.get_attr("stride_1") + stride_2 = corr_op.get_attr("stride_2") + pad = corr_op.get_attr("pad") + + corr_grads = _correlation_ops.correlation_grad(gradients, + corr_op.inputs[0], + corr_op.inputs[1], + kernel_size, + max_displacement, + stride_1, + stride_2, + pad) + + # Return the gradients with respect to input_a and input_b + return corr_grads.backprops_a, corr_grads.backprops_b |
