summaryrefslogtreecommitdiff
path: root/Codes/flownet2/src/ops/correlation
diff options
context:
space:
mode:
Diffstat (limited to 'Codes/flownet2/src/ops/correlation')
-rw-r--r--Codes/flownet2/src/ops/correlation/correlation_grad_kernel.cc160
-rw-r--r--Codes/flownet2/src/ops/correlation/correlation_grad_kernel.cu.cc262
-rw-r--r--Codes/flownet2/src/ops/correlation/correlation_kernel.cc137
-rw-r--r--Codes/flownet2/src/ops/correlation/correlation_kernel.cu.cc153
-rw-r--r--Codes/flownet2/src/ops/correlation/correlation_kernel.h77
-rw-r--r--Codes/flownet2/src/ops/correlation/correlation_op.cc83
-rw-r--r--Codes/flownet2/src/ops/correlation/pad.cu.cc76
-rw-r--r--Codes/flownet2/src/ops/correlation/pad.h20
8 files changed, 968 insertions, 0 deletions
diff --git a/Codes/flownet2/src/ops/correlation/correlation_grad_kernel.cc b/Codes/flownet2/src/ops/correlation/correlation_grad_kernel.cc
new file mode 100644
index 0000000..4e92f45
--- /dev/null
+++ b/Codes/flownet2/src/ops/correlation/correlation_grad_kernel.cc
@@ -0,0 +1,160 @@
+#define EIGEN_USE_THREADS
+
+#include "correlation_kernel.h"
+#include "pad.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+// OpKernel that computes the gradients of the correlation with respect to both
+// of its inputs. Inputs: 0 = upstream gradients, 1 = input_a, 2 = input_b.
+// Outputs: 0 = gradient for input_a, 1 = gradient for input_b; each output has
+// the shape of the corresponding input.
+template<typename Device>
+class CorrelationGradKernel : public OpKernel {
+ public:
+  explicit CorrelationGradKernel(OpKernelConstruction *ctx) : OpKernel(ctx) {
+    // Get the attributes
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("kernel_size", &kernel_size));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("max_displacement", &max_displacement));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("stride_1", &stride_1));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("stride_2", &stride_2));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("pad", &pad));
+
+    OP_REQUIRES(ctx, kernel_size % 2 != 0, errors::InvalidArgument("kernel_size must be odd"));
+
+    // Both strides are used as divisors in Compute() and in the device kernels.
+    OP_REQUIRES(ctx, stride_1 > 0, errors::InvalidArgument("stride_1 must be positive"));
+    OP_REQUIRES(ctx, stride_2 > 0, errors::InvalidArgument("stride_2 must be positive"));
+  }
+
+  void Compute(OpKernelContext *ctx) override {
+    // Get the inputs and verify their ranks. The rank must be checked before
+    // calling tensor<float, 4>() on any of them.
+    const Tensor& gradients_t = ctx->input(0);
+    const Tensor& input_a_t = ctx->input(1);
+    const Tensor& input_b_t = ctx->input(2);
+
+    OP_REQUIRES(ctx, gradients_t.dims() == 4,
+                errors::InvalidArgument("gradients must have rank 4"));
+    OP_REQUIRES(ctx, input_a_t.dims() == 4, errors::InvalidArgument("input_a must have rank 4"));
+    OP_REQUIRES(ctx, input_b_t.dims() == 4, errors::InvalidArgument("input_b must have rank 4"));
+
+    // input_b is padded and indexed using the dimensions of input_a, so the
+    // two shapes have to agree.
+    OP_REQUIRES(ctx, input_a_t.shape().IsSameSize(input_b_t.shape()),
+                errors::InvalidArgument("input_a and input_b must have the same shape"));
+
+    // Get dimensions of the (unpadded) input
+    const int batch_size = input_a_t.dim_size(0);
+    const int in_height = input_a_t.dim_size(1);
+    const int in_width = input_a_t.dim_size(2);
+    const int in_channels = input_a_t.dim_size(3);
+    const int in_count_per_sample = in_height * in_width * in_channels;
+    const int padded_height = in_height + 2 * pad;
+    const int padded_width = in_width + 2 * pad;
+
+    // The size of unreachable border region on each side
+    const int kernel_radius = (kernel_size - 1) / 2;
+    const int border_size = max_displacement + kernel_radius;
+
+    // Calculate the dimensions of the correlation output (and therefore of the
+    // incoming gradient)
+    const int out_height = ceil((float)(padded_height - border_size * 2) / (float)stride_1);
+    const int out_width = ceil((float)(padded_width - border_size * 2) / (float)stride_1);
+
+    const int neighborhood_grid_radius = max_displacement / stride_2;
+    const int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
+    const int out_channels = neighborhood_grid_width * neighborhood_grid_width;
+
+    // The device kernels index the gradient buffer with the dimensions
+    // computed above, so a gradient of any other shape would be read out of
+    // bounds.
+    OP_REQUIRES(ctx,
+                gradients_t.dim_size(0) == batch_size &&
+                gradients_t.dim_size(1) == out_height &&
+                gradients_t.dim_size(2) == out_width &&
+                gradients_t.dim_size(3) == out_channels,
+                errors::InvalidArgument(
+                  "gradients must have shape [batch, out_height, out_width, out_channels]"));
+
+    // Allocate the memory for the outputs
+    Tensor *output_a_gradient_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input_a_t.shape(), &output_a_gradient_t));
+    Tensor *output_b_gradient_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(1, input_b_t.shape(), &output_b_gradient_t));
+
+    // The outputs have the shape of the inputs; if the inputs are empty there
+    // is nothing to write and no kernel needs to be launched.
+    if (input_a_t.NumElements() == 0) {
+      return;
+    }
+
+    // Get the tensors
+    auto gradients = gradients_t.tensor<float, 4>();
+    auto input_a = input_a_t.tensor<float, 4>();
+    auto input_b = input_b_t.tensor<float, 4>();
+    auto output_a_gradient = output_a_gradient_t->tensor<float, 4>();
+    auto output_b_gradient = output_b_gradient_t->tensor<float, 4>();
+
+    // Create temporary tensors for padded inputs
+    Tensor padded_input_a_t, padded_input_b_t;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_temp(DataTypeToEnum<float>::value,
+                                      TensorShape({ batch_size, padded_height, padded_width, in_channels }),
+                                      &padded_input_a_t));
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_temp(DataTypeToEnum<float>::value,
+                                      TensorShape({ batch_size, padded_height, padded_width, in_channels }),
+                                      &padded_input_b_t));
+    auto padded_input_a = padded_input_a_t.tensor<float, 4>();
+    auto padded_input_b = padded_input_b_t.tensor<float, 4>();
+
+    const Device& device = ctx->eigen_device<Device>();
+
+    // Pad the inputs
+    Pad(device,
+        input_a.data(),
+        batch_size,
+        in_height,
+        in_width,
+        in_channels,
+        padded_height,
+        padded_width,
+        padded_input_a.data());
+    Pad(device,
+        input_b.data(),
+        batch_size,
+        in_height,
+        in_width,
+        in_channels,
+        padded_height,
+        padded_width,
+        padded_input_b.data());
+
+    // Gradient w.r.t. input_a is computed from the padded input_b ...
+    CorrelationGradA(device,
+                     batch_size,
+                     out_width,
+                     out_height,
+                     out_channels,
+                     max_displacement,
+                     neighborhood_grid_radius,
+                     neighborhood_grid_width,
+                     kernel_radius,
+                     stride_1,
+                     stride_2,
+                     in_width,
+                     in_height,
+                     padded_width,
+                     padded_height,
+                     in_channels,
+                     in_count_per_sample,
+                     pad,
+                     padded_input_b.data(),
+                     gradients.data(),
+                     output_a_gradient.data());
+
+    // ... and the gradient w.r.t. input_b from the padded input_a.
+    CorrelationGradB(device,
+                     batch_size,
+                     out_width,
+                     out_height,
+                     out_channels,
+                     max_displacement,
+                     neighborhood_grid_radius,
+                     neighborhood_grid_width,
+                     kernel_radius,
+                     stride_1,
+                     stride_2,
+                     in_width,
+                     in_height,
+                     padded_width,
+                     padded_height,
+                     in_channels,
+                     in_count_per_sample,
+                     pad,
+                     padded_input_a.data(),
+                     gradients.data(),
+                     output_b_gradient.data());
+  }
+
+ private:
+  int kernel_size;       // side length of the correlation patch (odd)
+  int max_displacement;  // largest displacement searched in input_b
+  int stride_1;          // step between output positions
+  int stride_2;          // step between sampled displacements
+  int pad;               // zeros added on each spatial side of the inputs
+};
+
+REGISTER_KERNEL_BUILDER(Name("CorrelationGrad") // binds the op name to the kernel class above
+ .Device(DEVICE_GPU), // only a GPU registration is present in this file
+ CorrelationGradKernel<GPUDevice>)
+} // end namespace tensorflow
diff --git a/Codes/flownet2/src/ops/correlation/correlation_grad_kernel.cu.cc b/Codes/flownet2/src/ops/correlation/correlation_grad_kernel.cu.cc
new file mode 100644
index 0000000..19e3a40
--- /dev/null
+++ b/Codes/flownet2/src/ops/correlation/correlation_grad_kernel.cu.cc
@@ -0,0 +1,262 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#define ROUND_OFF 50000
+
+#include <stdio.h>
+#include <iostream>
+
+#include "correlation_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+__global__ void CorrelateDataBackward0(const int nthreads, // number of (y, x, channel) elements of one sample to process
+ int item, // index of the sample handled by this launch
+ int out_width,
+ int out_height,
+ int out_channels,
+ int max_displacement,
+ int neighborhood_grid_radius,
+ int neighborhood_grid_width,
+ int kernel_radius,
+ int stride_1,
+ int stride_2,
+ int in_width,
+ int in_height,
+ int padded_in_width,
+ int padded_in_height,
+ int in_channels,
+ int in_count_per_sample, // element offset between consecutive samples in output_a_gradient
+ int pad_size,
+ float *output_a_gradient, // result; indexed with the unpadded in_width / in_channels
+ const float *input_b, // indexed with padded_in_height / padded_in_width
+ const float *gradient) // indexed with out_height / out_width / out_channels
+{
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ int k = index % in_channels; // channels
+ int x = (index / in_channels) % in_width + pad_size; // w-pos
+ int y = (index / in_channels / in_width) % in_height + pad_size; // h-pos
+
+ // Get X,Y ranges and clamp
+ // round_off is a trick to enable integer division with ceil, even for
+ // negative numbers
+ // We use a large offset, for the inner part not to become negative.
+ const int round_off = ROUND_OFF;
+ const int round_off_s1 = stride_1 * round_off;
+
+ // We add round_off_s1 before the int division and subtract round_off after
+ // it, to ensure the formula matches ceil behavior:
+ int xmin = (x - 2 * kernel_radius - max_displacement + round_off_s1 - 1) / stride_1 + 1 -
+ round_off;
+ int ymin = (y - 2 * kernel_radius - max_displacement + round_off_s1 - 1) / stride_1 + 1 -
+ round_off;
+
+ // Same here:
+ int xmax = (x - max_displacement + round_off_s1) / stride_1 - round_off;
+ int ymax = (y - max_displacement + round_off_s1) / stride_1 - round_off;
+
+ float sum = 0;
+
+ if ((xmax >= 0) && (ymax >= 0) && (xmin <= out_width - 1) && (ymin <= out_height - 1)) {
+ xmin = max(0, xmin);
+ xmax = min(out_width - 1, xmax);
+
+ ymin = max(0, ymin);
+ ymax = min(out_height - 1, ymax);
+
+ for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; p++) {
+ for (int o = -neighborhood_grid_radius; o <= neighborhood_grid_radius; o++) {
+ // Get input_b data:
+ int s2o = stride_2 * o;
+ int s2p = stride_2 * p;
+ int idx_input_b = ((item * padded_in_height + (y + s2p)) * padded_in_width + (x + s2o)) *
+ in_channels + k;
+ float input_b_tmp = input_b[idx_input_b]; // input_b[x+s2o,y+s2p,k]
+
+ // Index offset for gradient in following loops:
+ int op = (p + neighborhood_grid_radius) * neighborhood_grid_width +
+ (o + neighborhood_grid_radius); // index [o,p]
+
+ for (int y = ymin; y <= ymax; y++) { // NOTE: this y shadows the outer y; here it is an output row
+ for (int x = xmin; x <= xmax; x++) { // NOTE: this x shadows the outer x; here it is an output column
+ // gradient[x,y,o,p]
+ int idx_gradient = ((item * out_height + y) * out_width + x) * out_channels + op;
+ sum += gradient[idx_gradient] * input_b_tmp;
+ }
+ }
+ }
+ }
+ }
+ const int sumelems = (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * in_channels; // normalisation: patch area times channel count
+ const int input_a_idx = ((y - pad_size) * in_width + (x - pad_size)) * in_channels + k; // outer x, y again: remove the padding offset
+ output_a_gradient[input_a_idx + item * in_count_per_sample] = sum / (float)sumelems;
+ }
+}
+
+__global__ void CorrelateDataBackward1(const int nthreads, // number of (y, x, channel) elements of one sample to process
+ int item, // index of the sample handled by this launch
+ int out_width,
+ int out_height,
+ int out_channels,
+ int max_displacement,
+ int neighborhood_grid_radius,
+ int neighborhood_grid_width,
+ int kernel_radius,
+ int stride_1,
+ int stride_2,
+ int in_width,
+ int in_height,
+ int padded_in_width,
+ int padded_in_height,
+ int in_channels,
+ int in_count_per_sample, // element offset between consecutive samples in output_b_gradient
+ int pad_size,
+ float *output_b_gradient, // result; indexed with the unpadded in_width / in_channels
+ const float *input_a, // indexed with padded_in_height / padded_in_width
+ const float *gradient) // indexed with out_height / out_width / out_channels
+{
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ int k = index % in_channels; // channels
+ int x = (index / in_channels) % in_width + pad_size; // w-pos
+ int y = (index / in_channels / in_width) % in_height + pad_size; // h-pos
+
+ // round_off is a trick to enable integer division with ceil, even for
+ // negative numbers
+ // We use a large offset, for the inner part not to become negative.
+ const int round_off = ROUND_OFF;
+ const int round_off_s1 = stride_1 * round_off;
+
+ float sum = 0;
+
+ // Height (y)
+ for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; p++) {
+ // Width (x)
+ for (int o = -neighborhood_grid_radius; o <= neighborhood_grid_radius; o++) {
+ int s2o = stride_2 * o;
+ int s2p = stride_2 * p;
+
+ // Get X,Y ranges and clamp
+ // We add round_off_s1 before the int division and subtract round_off
+ // after it, to ensure the formula matches ceil behavior:
+ int xmin = (x - 2 * kernel_radius - max_displacement - s2o + round_off_s1 - 1) / stride_1 +
+ 1 - round_off;
+ int ymin = (y - 2 * kernel_radius - max_displacement - s2p + round_off_s1 - 1) / stride_1 +
+ 1 - round_off;
+
+ // Caffe, NKHW: ((n * K + k) * H + h) * W + w at point (n, k, h, w)
+ // TF, NHWK: ((n * H + h) * W + w) * K + k at point (n, h, w, k)
+
+ // Same here:
+ int xmax = (x - max_displacement - s2o + round_off_s1) / stride_1 - round_off;
+ int ymax = (y - max_displacement - s2p + round_off_s1) / stride_1 - round_off;
+
+ if ((xmax >= 0) && (ymax >= 0) && (xmin <= out_width - 1) && (ymin <= out_height - 1)) {
+ xmin = max(0, xmin);
+ xmax = min(out_width - 1, xmax);
+
+ ymin = max(0, ymin);
+ ymax = min(out_height - 1, ymax);
+
+ // Get input_a data:
+ int idx_input_a = ((item * padded_in_height + (y - s2p)) * padded_in_width + (x - s2o)) *
+ in_channels + k;
+ float input_a_tmp = input_a[idx_input_a];
+
+ // Index offset for gradient in following loops:
+ int op = (p + neighborhood_grid_radius) * neighborhood_grid_width +
+ (o + neighborhood_grid_radius); // index [o,p]
+
+ for (int y = ymin; y <= ymax; y++) { // NOTE: this y shadows the outer y; here it is an output row
+ for (int x = xmin; x <= xmax; x++) { // NOTE: this x shadows the outer x; here it is an output column
+ int idx_gradient = ((item * out_height + y) * out_width + x) * out_channels + op;
+ sum += gradient[idx_gradient] * input_a_tmp;
+ }
+ }
+ }
+ }
+ }
+ const int sumelems = (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * in_channels; // normalisation: patch area times channel count
+ const int input_b_idx = ((y - pad_size) * in_width + (x - pad_size)) * in_channels + k; // outer x, y again: remove the padding offset
+ output_b_gradient[input_b_idx + item * in_count_per_sample] = sum / (float)sumelems;
+ }
+}
+
+// Host-side launcher for CorrelateDataBackward0: computes the gradient with
+// respect to input_a. One kernel launch is enqueued per sample on
+// device.stream(); each launch covers in_count_per_sample elements.
+//
+// input_b is the padded second input, gradient is the upstream gradient of the
+// correlation output and output_a_gradient receives the result (unpadded
+// layout). All three are passed straight through to the device kernel.
+void CorrelationGradA(const GPUDevice& device,
+                      const int batch_size,
+                      const int out_width,
+                      const int out_height,
+                      const int out_channels,
+                      const int max_displacement,
+                      const int neighborhood_grid_radius,
+                      const int neighborhood_grid_width,
+                      const int kernel_radius,
+                      const int stride_1,
+                      const int stride_2,
+                      const int in_width,
+                      const int in_height,
+                      const int padded_in_width,
+                      const int padded_in_height,
+                      const int in_channels,
+                      const int in_count_per_sample, // h * w * ch
+                      const int pad,
+                      const float *input_b,
+                      const float *gradient,
+                      float *output_a_gradient) {
+  // Nothing to compute for an empty batch or empty samples. Returning here
+  // also avoids asking for a launch configuration for zero work elements.
+  if (batch_size <= 0 || in_count_per_sample <= 0) {
+    return;
+  }
+
+  CudaLaunchConfig config = GetCudaLaunchConfig(in_count_per_sample, device);
+
+  for (int n = 0; n < batch_size; n++) {
+    CorrelateDataBackward0<<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
+      in_count_per_sample,
+      n, out_width, out_height, out_channels,
+      max_displacement, neighborhood_grid_radius, neighborhood_grid_width, kernel_radius,
+      stride_1, stride_2,
+      in_width, in_height, padded_in_width, padded_in_height, in_channels, in_count_per_sample, pad,
+      output_a_gradient, input_b, gradient);
+  }
+}
+
+// Host-side launcher for CorrelateDataBackward1: computes the gradient with
+// respect to input_b. One kernel launch is enqueued per sample on
+// device.stream(); each launch covers in_count_per_sample elements.
+//
+// input_a is the padded first input, gradient is the upstream gradient of the
+// correlation output and output_b_gradient receives the result (unpadded
+// layout). All three are passed straight through to the device kernel.
+void CorrelationGradB(const GPUDevice& device,
+                      const int batch_size,
+                      const int out_width,
+                      const int out_height,
+                      const int out_channels,
+                      const int max_displacement,
+                      const int neighborhood_grid_radius,
+                      const int neighborhood_grid_width,
+                      const int kernel_radius,
+                      const int stride_1,
+                      const int stride_2,
+                      const int in_width,
+                      const int in_height,
+                      const int padded_in_width,
+                      const int padded_in_height,
+                      const int in_channels,
+                      const int in_count_per_sample,
+                      const int pad,
+                      const float *input_a,
+                      const float *gradient,
+                      float *output_b_gradient) {
+  // Nothing to compute for an empty batch or empty samples. Returning here
+  // also avoids asking for a launch configuration for zero work elements.
+  if (batch_size <= 0 || in_count_per_sample <= 0) {
+    return;
+  }
+
+  CudaLaunchConfig config = GetCudaLaunchConfig(in_count_per_sample, device);
+
+  for (int n = 0; n < batch_size; n++) {
+    CorrelateDataBackward1<<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
+      in_count_per_sample,
+      n, out_width, out_height, out_channels,
+      max_displacement, neighborhood_grid_radius, neighborhood_grid_width, kernel_radius,
+      stride_1, stride_2,
+      in_width, in_height, padded_in_width, padded_in_height, in_channels, in_count_per_sample, pad,
+      output_b_gradient, input_a, gradient);
+  }
+}
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/Codes/flownet2/src/ops/correlation/correlation_kernel.cc b/Codes/flownet2/src/ops/correlation/correlation_kernel.cc
new file mode 100644
index 0000000..f8a5193
--- /dev/null
+++ b/Codes/flownet2/src/ops/correlation/correlation_kernel.cc
@@ -0,0 +1,137 @@
+#define EIGEN_USE_THREADS
+
+#include <utility>
+
+#include "correlation_kernel.h"
+#include "pad.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+// OpKernel for the forward correlation. Inputs: 0 = input_a, 1 = input_b, both
+// rank 4 and of identical shape. Output 0 has shape
+// [batch, output_height, output_width, output_channels], where output_channels
+// is the number of sampled displacements.
+template<typename Device>
+class CorrelationKernel : public OpKernel {
+ public:
+  explicit CorrelationKernel(OpKernelConstruction *ctx) : OpKernel(ctx) {
+    // Get the attributes
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("kernel_size", &kernel_size));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("max_displacement", &max_displacement));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("stride_1", &stride_1));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("stride_2", &stride_2));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("pad", &pad));
+
+    OP_REQUIRES(ctx, kernel_size % 2 != 0, errors::InvalidArgument("kernel_size must be odd"));
+
+    // Both strides are used as divisors in Compute().
+    OP_REQUIRES(ctx, stride_1 > 0, errors::InvalidArgument("stride_1 must be positive"));
+    OP_REQUIRES(ctx, stride_2 > 0, errors::InvalidArgument("stride_2 must be positive"));
+  }
+
+  void Compute(OpKernelContext *ctx) override {
+    // Get the two input images and verify their dimensions
+    const Tensor& input_a_t = ctx->input(0);
+    const Tensor& input_b_t = ctx->input(1);
+
+    OP_REQUIRES(ctx, input_a_t.dims() == 4, errors::InvalidArgument("input_a must have rank 4"));
+    OP_REQUIRES(ctx, input_b_t.dims() == 4, errors::InvalidArgument("input_b must have rank 4"));
+
+    // input_b is padded and indexed using the dimensions of input_a, so the
+    // two shapes have to agree.
+    OP_REQUIRES(ctx, input_a_t.shape().IsSameSize(input_b_t.shape()),
+                errors::InvalidArgument("input_a and input_b must have the same shape"));
+
+    // Get dimensions of the input as passed in; the padding is added below
+    int batch_size = input_a_t.dim_size(0);
+    int input_height = input_a_t.dim_size(1);
+    int input_width = input_a_t.dim_size(2);
+    int input_channels = input_a_t.dim_size(3);
+    int padded_height = input_height + 2 * pad;
+    int padded_width = input_width + 2 * pad;
+
+    // The size of unreachable border region on each side
+    int kernel_radius = (kernel_size - 1) / 2;
+    int border_size = max_displacement + kernel_radius;
+
+    // Calculate the output dimensions
+    int output_height = ceil((float)(padded_height - border_size * 2) / (float)stride_1);
+    int output_width = ceil((float)(padded_width - border_size * 2) / (float)stride_1);
+
+    OP_REQUIRES(ctx, output_height >= 1,
+                errors::InvalidArgument("Neighborhood and kernel don't fit in input height."));
+    OP_REQUIRES(ctx, output_width >= 1,
+                errors::InvalidArgument("Neighborhood and kernel don't fit in input width."));
+
+    int neighborhood_grid_radius = max_displacement / stride_2;
+    int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
+    int output_channels = neighborhood_grid_width * neighborhood_grid_width;
+
+    // Allocate the memory for the output
+    Tensor *output_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(
+                     0,
+                     TensorShape({ batch_size, output_height, output_width, output_channels }),
+                     &output_t));
+
+    // An empty batch has an empty output; launching the kernels would request
+    // a grid with a zero-sized dimension.
+    if (batch_size == 0) {
+      return;
+    }
+
+    // Get the tensors
+    auto input_a = input_a_t.tensor<float, 4>();
+    auto input_b = input_b_t.tensor<float, 4>();
+    auto output = output_t->tensor<float, 4>();
+
+    // Create temporary tensors for padded inputs
+    Tensor padded_input_a_t, padded_input_b_t;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_temp(DataTypeToEnum<float>::value,
+                                      TensorShape({ batch_size, padded_height, padded_width, input_channels }),
+                                      &padded_input_a_t));
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_temp(DataTypeToEnum<float>::value,
+                                      TensorShape({ batch_size, padded_height, padded_width, input_channels }),
+                                      &padded_input_b_t));
+    auto padded_input_a = padded_input_a_t.tensor<float, 4>();
+    auto padded_input_b = padded_input_b_t.tensor<float, 4>();
+
+    // Pad the inputs
+    Pad(ctx->eigen_device<Device>(),
+        input_a.data(),
+        batch_size,
+        input_height,
+        input_width,
+        input_channels,
+        padded_height,
+        padded_width,
+        padded_input_a.data());
+    Pad(ctx->eigen_device<Device>(),
+        input_b.data(),
+        batch_size,
+        input_height,
+        input_width,
+        input_channels,
+        padded_height,
+        padded_width,
+        padded_input_b.data());
+
+    // Perform cross correlation
+    Correlation(ctx->eigen_device<Device>(),
+                padded_input_a.data(),
+                padded_input_b.data(),
+                batch_size,
+                output_height,
+                output_width,
+                output_channels,
+                output_height * output_width * output_channels,
+                padded_height,
+                padded_width,
+                input_channels,
+                max_displacement,
+                neighborhood_grid_radius,
+                neighborhood_grid_width,
+                kernel_radius,
+                kernel_size,
+                stride_1,
+                stride_2,
+                output.data());
+  }
+
+ private:
+  int kernel_size;       // side length of the correlation patch (odd)
+  int max_displacement;  // largest displacement searched in input_b
+  int stride_1;          // step between output positions
+  int stride_2;          // step between sampled displacements
+  int pad;               // zeros added on each spatial side of the inputs
+};
+
+REGISTER_KERNEL_BUILDER(Name("Correlation") // binds the op name to the kernel class above
+ .Device(DEVICE_GPU), // only a GPU registration is present in this file
+ CorrelationKernel<GPUDevice>)
+} // end namespace tensorflow
diff --git a/Codes/flownet2/src/ops/correlation/correlation_kernel.cu.cc b/Codes/flownet2/src/ops/correlation/correlation_kernel.cu.cc
new file mode 100644
index 0000000..c63e489
--- /dev/null
+++ b/Codes/flownet2/src/ops/correlation/correlation_kernel.cu.cc
@@ -0,0 +1,153 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#define WARPS_PER_BLOCK 1
+#define THREADS_PER_WARP 32
+
+#include <stdio.h>
+#include <iostream>
+
+#include "correlation_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+// Forward correlation kernel.
+//
+// Launch layout: blockIdx.x / blockIdx.y select the output column / row,
+// blockIdx.z selects the sample, and the WARPS_PER_BLOCK * THREADS_PER_WARP
+// threads of a block split the channel dimension between them.
+// Dynamic shared memory: kernel_size * kernel_size * in_channels floats, used
+// to cache the patch of input_a that is compared against every displacement.
+// input_a and input_b are indexed with the padded dimensions; output is
+// written as [item][y][x][out_channel].
+// batch_size, out_height and kernel_radius are accepted but not read here.
+__global__ void CorrelateData(int batch_size,
+                              int out_width,
+                              int out_height,
+                              int out_channels,
+                              int out_count,
+                              int max_displacement,
+                              int neighborhood_grid_radius,
+                              int neighborhood_grid_width,
+                              int kernel_radius,
+                              int kernel_size,
+                              int stride_1,
+                              int stride_2,
+                              int in_width_padded,
+                              int in_height_padded,
+                              int in_channels,
+                              const float *input_a,
+                              const float *input_b,
+                              float *output) {
+  extern __shared__ char patch_data_char[];
+
+  float *patch_data = (float *)patch_data_char;
+
+  // First (upper left) position of kernel upper-left corner in current center
+  // position of neighborhood in image 1
+  int x1 = blockIdx.x * stride_1 + max_displacement;
+  int y1 = blockIdx.y * stride_1 + max_displacement;
+  int item = blockIdx.z;
+  int ch_off = threadIdx.x;
+
+  // Load 3D patch into shared memory
+  // HEIGHT
+  for (int j = 0; j < kernel_size; j++) {
+    // WIDTH
+    for (int i = 0; i < kernel_size; i++) {
+      int ji_off = ((j * kernel_size) + i) * in_channels;
+
+      // CHANNELS
+      for (int ch = ch_off; ch < in_channels; ch += (WARPS_PER_BLOCK * THREADS_PER_WARP)) {
+        int idx1 = ((item * in_height_padded + y1 + j) * in_width_padded + x1 + i) *
+                   in_channels + ch;
+        int idxPatchData = ji_off + ch;
+        patch_data[idxPatchData] = input_a[idx1];
+      }
+    }
+  }
+
+  // The whole patch must be in shared memory before any thread reads it.
+  __syncthreads();
+
+  // One partial sum per thread; thread 0 adds them up for each displacement.
+  __shared__ float sum[WARPS_PER_BLOCK * THREADS_PER_WARP];
+
+  // Compute correlation
+  for (int out_channel = 0; out_channel < out_channels; out_channel++) {
+    sum[ch_off] = 0;
+
+    // Displacement (s2o, s2p) encoded by this output channel
+    int s2o = (out_channel % neighborhood_grid_width - neighborhood_grid_radius) * stride_2;
+    int s2p = (out_channel / neighborhood_grid_width - neighborhood_grid_radius) * stride_2;
+    int x2 = x1 + s2o;
+    int y2 = y1 + s2p;
+
+    // HEIGHT
+    for (int j = 0; j < kernel_size; j++) {
+      // WIDTH
+      for (int i = 0; i < kernel_size; i++) {
+        int ji_off = ((j * kernel_size) + i) * in_channels;
+
+        // CHANNELS
+        for (int ch = ch_off; ch < in_channels; ch += (WARPS_PER_BLOCK * THREADS_PER_WARP)) {
+          int idxPatchData = ji_off + ch;
+          int idx2 = ((item * in_height_padded + y2 + j) * in_width_padded + x2 + i) *
+                     in_channels + ch;
+
+          sum[ch_off] += patch_data[idxPatchData] * input_b[idx2];
+        }
+      }
+    }
+
+    // All partial sums must be written before thread 0 reads them.
+    __syncthreads();
+
+    if (ch_off == 0) {
+      float total_sum = 0;
+
+      for (int idx = 0; idx < WARPS_PER_BLOCK * THREADS_PER_WARP; idx++) {
+        total_sum += sum[idx];
+      }
+      const int sumelems = kernel_size * kernel_size * in_channels;
+      const int index = (blockIdx.y * out_width + blockIdx.x) * out_channels + out_channel;
+
+      /* from Caffe: const int index = ((out_channel * out_height +
+         blockIdx.y) * out_width) + blockIdx.x; */
+      output[index + item * out_count] = total_sum / (float)sumelems;
+
+      // Caffe, NKHW: ((n * K + k) * H + h) * W + w at point (n, k, h, w)
+      // TF, NHWK: ((n * H + h) * W + w) * K + k at point (n, h, w, k)
+      // n = 0
+      // caffe: ((k * H + h) * W + w) + n * K * H * W
+      // tf: (h * W + w) * K + k + n * H * W * K
+    }
+
+    // Thread 0 must finish reading sum[] before any thread resets its slot at
+    // the top of the next iteration. Threads of a warp are not guaranteed to
+    // run in lockstep, so without this barrier the reset can race with the
+    // reduction above. The barrier is outside the branch so every thread of
+    // the block reaches it.
+    __syncthreads();
+  }
+}
+
+// Host-side launcher for the CorrelateData kernel. Enqueues a single launch on
+// device.stream() that covers every output position of every sample.
+// input_a / input_b are the padded inputs; output receives the correlation.
+void Correlation(const GPUDevice& device,
+                 const float *input_a,
+                 const float *input_b,
+                 const int batch_size,
+                 const int out_height,
+                 const int out_width,
+                 const int out_channels,
+                 const int out_count,
+                 const int in_height_padded,
+                 const int in_width_padded,
+                 const int in_channels,
+                 int max_displacement,
+                 int neighborhood_grid_radius,
+                 int neighborhood_grid_width,
+                 int kernel_radius,
+                 int kernel_size,
+                 int stride_1,
+                 int stride_2,
+                 float *output) {
+  // Number of floats in one kernel_size x kernel_size x in_channels patch;
+  // the kernel caches exactly one such patch in dynamic shared memory.
+  const int patch_floats = (kernel_size * kernel_size) * in_channels;
+
+  // One block per (output column, output row, sample).
+  dim3 grid(out_width, out_height, batch_size);
+  dim3 block(THREADS_PER_WARP * WARPS_PER_BLOCK);
+
+  CorrelateData<<<grid, block, patch_floats * sizeof(float), device.stream()>>>(
+    batch_size, out_width, out_height, out_channels, out_count,
+    max_displacement, neighborhood_grid_radius, neighborhood_grid_width, kernel_radius,
+    kernel_size, stride_1, stride_2, in_width_padded, in_height_padded, in_channels,
+    input_a, input_b, output);
+}
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/Codes/flownet2/src/ops/correlation/correlation_kernel.h b/Codes/flownet2/src/ops/correlation/correlation_kernel.h
new file mode 100644
index 0000000..a1dfb62
--- /dev/null
+++ b/Codes/flownet2/src/ops/correlation/correlation_kernel.h
@@ -0,0 +1,77 @@
+#ifndef FLOWNET_CORRELATION_H_
+#define FLOWNET_CORRELATION_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+void Correlation(const GPUDevice& device, // forward correlation launcher; the kernel is enqueued on device.stream()
+ const float *input_a, // padded first input
+ const float *input_b, // padded second input
+ const int batch_size,
+ const int out_height,
+ const int out_width,
+ const int out_channels,
+ const int out_count, // number of output elements per sample
+ const int in_height_padded,
+ const int in_width_padded,
+ const int in_channels,
+ int max_displacement,
+ int neighborhood_grid_radius,
+ int neighborhood_grid_width,
+ int kernel_radius,
+ int kernel_size,
+ int stride_1,
+ int stride_2,
+ float *output); // receives the correlation result
+
+
+void CorrelationGradA(const GPUDevice& device, // launcher for the gradient w.r.t. input_a; kernels are enqueued on device.stream()
+ const int batch_size,
+ const int out_width, // NOTE: width comes before height here, unlike in Correlation()
+ const int out_height,
+ const int out_channels,
+ const int max_displacement,
+ const int neighborhood_grid_radius,
+ const int neighborhood_grid_width,
+ const int kernel_radius,
+ const int stride_1,
+ const int stride_2,
+ const int in_width,
+ const int in_height,
+ const int padded_in_width,
+ const int padded_in_height,
+ const int in_channels,
+ const int in_count_per_sample, // in_height * in_width * in_channels
+ const int pad,
+ const float *input_b, // padded second input
+ const float *gradient, // upstream gradient of the correlation output
+ float *output_a_gradient); // receives the gradient for input_a (unpadded layout)
+
+void CorrelationGradB(const GPUDevice& device, // launcher for the gradient w.r.t. input_b; kernels are enqueued on device.stream()
+ const int batch_size,
+ const int out_width, // NOTE: width comes before height here, unlike in Correlation()
+ const int out_height,
+ const int out_channels,
+ const int max_displacement,
+ const int neighborhood_grid_radius,
+ const int neighborhood_grid_width,
+ const int kernel_radius,
+ const int stride_1,
+ const int stride_2,
+ const int in_width,
+ const int in_height,
+ const int padded_in_width,
+ const int padded_in_height,
+ const int in_channels,
+ const int in_count_per_sample, // in_height * in_width * in_channels
+ const int pad,
+ const float *input_a, // padded first input
+ const float *gradient, // upstream gradient of the correlation output
+ float *output_b_gradient); // receives the gradient for input_b (unpadded layout)
+} // end namespace tensorflow
+
+#endif // FLOWNET_CORRELATION_H_
diff --git a/Codes/flownet2/src/ops/correlation/correlation_op.cc b/Codes/flownet2/src/ops/correlation/correlation_op.cc
new file mode 100644
index 0000000..4f420f0
--- /dev/null
+++ b/Codes/flownet2/src/ops/correlation/correlation_op.cc
@@ -0,0 +1,83 @@
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+// Shape function for the "Correlation" op.
+//
+// Both inputs must be rank 4 and mergeable into one shape. The output is
+// [batch, output_height, output_width, output_channels]. A spatial dimension
+// that is unknown at graph-construction time stays unknown in the output
+// instead of being fed into the size arithmetic.
+//
+// Returns InvalidArgument if a stride is not positive or if a known spatial
+// dimension is too small for the kernel and displacement neighborhood.
+Status SetOutput(InferenceContext *c) {
+  ShapeHandle input_a, input_b, input;
+
+  // Get shapes of both inputs and verify they are rank 4
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_a));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &input_b));
+
+  // Verify inputs are same dimensions
+  TF_RETURN_IF_ERROR(c->Merge(input_a, input_b, &input));
+
+  // Get the attributes
+  int kernel_size, max_displacement, stride_1, stride_2, pad;
+  TF_RETURN_IF_ERROR(c->GetAttr("kernel_size", &kernel_size));
+  TF_RETURN_IF_ERROR(c->GetAttr("max_displacement", &max_displacement));
+  TF_RETURN_IF_ERROR(c->GetAttr("stride_1", &stride_1));
+  TF_RETURN_IF_ERROR(c->GetAttr("stride_2", &stride_2));
+  TF_RETURN_IF_ERROR(c->GetAttr("pad", &pad));
+
+  // Both strides are used as divisors below.
+  if (stride_1 <= 0 || stride_2 <= 0) {
+    return errors::InvalidArgument("stride_1 and stride_2 must be positive");
+  }
+
+  // The size of unreachable border region on each side
+  const int kernel_radius = (kernel_size - 1) / 2;
+  const int border_size = max_displacement + kernel_radius;
+
+  // Computes the output extent along one spatial axis of the input:
+  // ceil((extent + 2 * pad - 2 * border_size) / stride_1), in integer
+  // arithmetic. `what` is only used in the error message.
+  auto output_extent = [&](int axis, const char *what,
+                           shape_inference::DimensionHandle *out) -> Status {
+    shape_inference::DimensionHandle in_dim = c->Dim(input, axis);
+
+    if (!c->ValueKnown(in_dim)) {
+      *out = c->UnknownDim();
+      return Status::OK();
+    }
+
+    const int64 usable = c->Value(in_dim) + 2 * static_cast<int64>(pad) -
+                         2 * static_cast<int64>(border_size);
+
+    // usable < 1 is exactly the case in which the output extent would be < 1.
+    if (usable < 1) {
+      return errors::InvalidArgument("Neighborhood and kernel don't fit in input ", what, ".");
+    }
+    *out = c->MakeDim((usable + stride_1 - 1) / stride_1);
+    return Status::OK();
+  };
+
+  // Calculate the output dimensions
+  shape_inference::DimensionHandle output_height, output_width;
+  TF_RETURN_IF_ERROR(output_extent(1, "height", &output_height));
+  TF_RETURN_IF_ERROR(output_extent(2, "width", &output_width));
+
+  // One output channel per sampled displacement
+  const int neighborhood_grid_radius = max_displacement / stride_2;
+  const int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
+  const int64 output_channels = neighborhood_grid_width * neighborhood_grid_width;
+
+  // Set output shape; the batch dimension is passed through unchanged
+  c->set_output(0, c->MakeShape({ c->Dim(input, 0), output_height, output_width,
+                                  c->MakeDim(output_channels) }));
+  return Status::OK();
+}
+
+REGISTER_OP("Correlation") // forward op; output shape is computed by SetOutput
+.Input("input_a: float32")
+.Input("input_b: float32")
+.Attr("kernel_size: int") // side length of the correlation patch; the GPU kernel requires it to be odd
+.Attr("max_displacement: int") // largest displacement searched in input_b
+.Attr("stride_1: int") // step between output positions
+.Attr("stride_2: int") // step between sampled displacements
+.Attr("pad: int") // zeros added on each spatial side of both inputs
+.Output("output: float32")
+.SetShapeFn(SetOutput);
+
+REGISTER_OP("CorrelationGrad") // backward op; takes the upstream gradient plus both forward inputs
+.Input("gradients: float32")
+.Input("input_a: float32")
+.Input("input_b: float32")
+.Attr("kernel_size: int") // same attributes as the forward "Correlation" op
+.Attr("max_displacement: int")
+.Attr("stride_1: int")
+.Attr("stride_2: int")
+.Attr("pad: int")
+.Output("backprops_a: float32") // gradient for input_a
+.Output("backprops_b: float32") // gradient for input_b
+.SetShapeFn([](InferenceContext *c) {
+ // Both output gradients get the merged shape of input_a (input 1) and input_b (input 2)
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->input(2), &out));
+ c->set_output(0, out);
+ c->set_output(1, out);
+ return Status::OK();
+ });
+} // namespace tensorflow
diff --git a/Codes/flownet2/src/ops/correlation/pad.cu.cc b/Codes/flownet2/src/ops/correlation/pad.cu.cc
new file mode 100644
index 0000000..0b6c93d
--- /dev/null
+++ b/Codes/flownet2/src/ops/correlation/pad.cu.cc
@@ -0,0 +1,76 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+#include <iostream>
+
+#include "pad.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+// Copies one input element per thread into the interior of the padded output.
+//
+// Launch layout: blockIdx.x * blockDim.x + threadIdx.x enumerates the
+// in_width * in_height pixels of one channel plane, blockIdx.y is the channel
+// and blockIdx.z the sample. Both buffers are indexed as
+// [sample][row][column][channel].
+//
+// Only the interior is written: the caller must have zero-filled `out`
+// beforehand so that the border holds zeros.
+__global__ void PadData(
+  const float *in,
+  int in_widthheight,
+  int in_width,
+  int in_height,
+  int out_width,
+  int out_height,
+  int channels,
+  int padding,
+  float *out) {
+  const int xy = blockIdx.x * blockDim.x + threadIdx.x;
+
+  // Threads past the last pixel have nothing to copy. They must not write at
+  // all: their (x, y) has y >= in_height, which can alias a valid interior
+  // element of the padded output (or lie outside the buffer), so storing a
+  // zero there would race with the thread that stores the real value.
+  if (xy >= in_widthheight) {
+    return;
+  }
+
+  const int x = xy % in_width;
+  const int y = xy / in_width;
+  const int ch = blockIdx.y;
+  const int n = blockIdx.z;
+
+  const float value = in[((n * in_height + y) * in_width + x) * channels + ch];
+
+  // Shift by the border width to land in the interior of the output
+  const int xpad = x + padding;
+  const int ypad = y + padding;
+
+  out[((n * out_height + ypad) * out_width + xpad) * channels + ch] = value;
+}
+
+// Zero-pads `input` ([batch, input_height, input_width, channels]) into
+// `output` ([batch, output_height, output_width, channels]). The border width
+// is (output_height - input_height) / 2 and is applied to rows and columns
+// alike. All work is enqueued on device.stream(); the function does not wait
+// for it to finish.
+void Pad(const GPUDevice& device,
+         const float *input,
+         int batch_size,
+         int input_height,
+         int input_width,
+         int input_channels,
+         int output_height,
+         int output_width,
+         float *output) {
+  // An empty output needs neither clearing nor copying.
+  if (batch_size <= 0 || input_channels <= 0 || output_height <= 0 || output_width <= 0) {
+    return;
+  }
+
+  // Clear the whole output so the border is zero. The clear is issued on the
+  // same stream as the copy kernel, which orders it before the kernel without
+  // depending on default-stream synchronisation. The byte count is computed in
+  // size_t so it cannot overflow int for large tensors.
+  const size_t output_bytes = static_cast<size_t>(batch_size) * output_height * output_width *
+                              input_channels * sizeof(float);
+  cudaMemsetAsync(output, 0, output_bytes, device.stream());
+
+  // With an empty input the zero-filled output is already the final result.
+  const int in_widthheight = input_width * input_height;
+  if (in_widthheight <= 0) {
+    return;
+  }
+
+  // One thread per pixel along x; channels and samples map to grid y and z.
+  const int threads_per_block = 16;
+  const dim3 totalBlocks((in_widthheight - 1) / threads_per_block + 1, input_channels, batch_size);
+
+  const int padding = (output_height - input_height) / 2;
+
+  // LAUNCH KERNEL
+  PadData<<<totalBlocks, threads_per_block, 0, device.stream()>>>(
+    input,
+    in_widthheight,
+    input_width,
+    input_height,
+    output_width,
+    output_height,
+    input_channels,
+    padding,
+    output);
+}
+}
+#endif // if GOOGLE_CUDA
diff --git a/Codes/flownet2/src/ops/correlation/pad.h b/Codes/flownet2/src/ops/correlation/pad.h
new file mode 100644
index 0000000..afb4df0
--- /dev/null
+++ b/Codes/flownet2/src/ops/correlation/pad.h
@@ -0,0 +1,20 @@
+#ifndef FLOWNET_PAD_H_
+#define FLOWNET_PAD_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+void Pad(const GPUDevice& device, // zero-pads input into output; work is enqueued on device.stream()
+ const float *input, // indexed [batch][row][column][channel] with the input_* dimensions
+ int batch_size,
+ int input_height,
+ int input_width,
+ int input_channels,
+ int output_height, // padded height; the border width is derived as (output_height - input_height) / 2
+ int output_width, // padded width
+ float *output); // indexed like input but with the output_* dimensions
+} // end namespace tensorflow
+
+#endif // ifndef FLOWNET_PAD_H_