diff options
Diffstat (limited to 'Codes/flownet2/src/ops/correlation')
8 files changed, 968 insertions, 0 deletions
diff --git a/Codes/flownet2/src/ops/correlation/correlation_grad_kernel.cc b/Codes/flownet2/src/ops/correlation/correlation_grad_kernel.cc new file mode 100644 index 0000000..4e92f45 --- /dev/null +++ b/Codes/flownet2/src/ops/correlation/correlation_grad_kernel.cc @@ -0,0 +1,160 @@ +#define EIGEN_USE_THREADS + +#include "correlation_kernel.h" +#include "pad.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; + +template<typename Device> +class CorrelationGradKernel : public OpKernel { + public: + explicit CorrelationGradKernel(OpKernelConstruction *ctx) : OpKernel(ctx) { + // Get the attributes + OP_REQUIRES_OK(ctx, ctx->GetAttr("kernel_size", &kernel_size)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_displacement", &max_displacement)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("stride_1", &stride_1)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("stride_2", &stride_2)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("pad", &pad)); + + OP_REQUIRES(ctx, kernel_size % 2 != 0, errors::InvalidArgument("kernel_size must be odd")); + } + + void Compute(OpKernelContext *ctx) override { + // Get the input images and verify their dimensions + const Tensor& gradients_t = ctx->input(0); + const Tensor& input_a_t = ctx->input(1); + const Tensor& input_b_t = ctx->input(2); + + OP_REQUIRES(ctx, input_a_t.dims() == 4, errors::InvalidArgument("input_a must have rank 4")); + OP_REQUIRES(ctx, input_b_t.dims() == 4, errors::InvalidArgument("input_b must have rank 4")); + + // Get dimensions of input + const int batch_size = input_a_t.dim_size(0); + const int in_height = input_a_t.dim_size(1); + const int in_width = input_a_t.dim_size(2); + const int in_channels = input_a_t.dim_size(3); + const int in_count_per_sample = in_height * in_width * in_channels; + const int padded_height = in_height + 2 * pad; + const int padded_width = in_width + 2 * pad; + + // The size of unreachable border region on each side + const int kernel_radius = (kernel_size - 1) / 2; + const int border_size = max_displacement + kernel_radius; + + // Calculate the output dimensions + const int out_height = ceil((float)(padded_height - border_size * 2) / (float)stride_1); + const int out_width = ceil((float)(padded_width - border_size * 2) / (float)stride_1); + + const int neighborhood_grid_radius = max_displacement / stride_2; + const int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; + const int out_channels = neighborhood_grid_width * neighborhood_grid_width; + + // Allocate the memory for the outputs + Tensor *output_a_gradient_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input_a_t.shape(), &output_a_gradient_t)); + Tensor *output_b_gradient_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, input_b_t.shape(), &output_b_gradient_t)); + + // Get the tensors + auto gradients = gradients_t.tensor<float, 4>(); + auto input_a = input_a_t.tensor<float, 4>(); + auto input_b = input_b_t.tensor<float, 4>(); + auto output_a_gradient = output_a_gradient_t->tensor<float, 4>(); + auto output_b_gradient = output_b_gradient_t->tensor<float, 4>(); + + // Create temporary tensors for padded inputs + Tensor padded_input_a_t, padded_input_b_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum<float>::value, + TensorShape({ batch_size, padded_height, padded_width, in_channels }), + &padded_input_a_t)); + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum<float>::value, + TensorShape({ batch_size, padded_height, padded_width, in_channels }), + &padded_input_b_t)); + auto padded_input_a = padded_input_a_t.tensor<float, 4>(); + auto padded_input_b = padded_input_b_t.tensor<float, 4>(); + + // Pad the inputs + Pad(ctx->eigen_device<Device>(), + input_a.data(), + batch_size, + in_height, + in_width, + in_channels, + padded_height, + padded_width, + padded_input_a.data()); + Pad(ctx->eigen_device<Device>(), + input_b.data(), + batch_size, + in_height, + in_width, + in_channels, + padded_height, + padded_width, + padded_input_b.data()); + + CorrelationGradA(ctx->eigen_gpu_device(), + batch_size, + out_width, + out_height, + out_channels, + max_displacement, + neighborhood_grid_radius, + neighborhood_grid_width, + kernel_radius, + stride_1, + stride_2, + in_width, + in_height, + padded_width, + padded_height, + in_channels, + in_count_per_sample, + pad, + padded_input_b.data(), + gradients.data(), + output_a_gradient.data()); + + CorrelationGradB(ctx->eigen_gpu_device(), + batch_size, + out_width, + out_height, + out_channels, + max_displacement, + neighborhood_grid_radius, + neighborhood_grid_width, + kernel_radius, + stride_1, + stride_2, + in_width, + in_height, + padded_width, + padded_height, + in_channels, + in_count_per_sample, + pad, + padded_input_a.data(), + gradients.data(), + output_b_gradient.data()); + } + + private: + int kernel_size; + int max_displacement; + int stride_1; + int stride_2; + int pad; +}; + +REGISTER_KERNEL_BUILDER(Name("CorrelationGrad") + .Device(DEVICE_GPU), + CorrelationGradKernel<GPUDevice>) +} // end namespace tensorflow diff --git a/Codes/flownet2/src/ops/correlation/correlation_grad_kernel.cu.cc b/Codes/flownet2/src/ops/correlation/correlation_grad_kernel.cu.cc new file mode 100644 index 0000000..19e3a40 --- /dev/null +++ b/Codes/flownet2/src/ops/correlation/correlation_grad_kernel.cu.cc @@ -0,0 +1,262 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#define ROUND_OFF 50000 + +#include <stdio.h> +#include <iostream> + +#include "correlation_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; + +__global__ void CorrelateDataBackward0(const int nthreads, + int item, + int out_width, + int out_height, + int out_channels, + int max_displacement, + int neighborhood_grid_radius, + int neighborhood_grid_width, + int kernel_radius, + int stride_1, + int stride_2, + int in_width, + int in_height, + int padded_in_width, + int padded_in_height, + int in_channels, + int in_count_per_sample, + int pad_size, + float *output_a_gradient, + const float *input_b, + const float *gradient) +{ + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int k = index % in_channels; // channels + int x = (index / in_channels) % in_width + pad_size; // w-pos + int y = (index / in_channels / in_width) % in_height + pad_size; // h-pos + + // Get X,Y ranges and clamp + // round_off is a trick to enable integer division with ceil, even for + // negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = stride_1 * round_off; + + // We add round_off before_s1 the int division and subtract round_off after + // it, to ensure the formula matches ceil behavior: + int xmin = (x - 2 * kernel_radius - max_displacement + round_off_s1 - 1) / stride_1 + 1 - + round_off; + int ymin = (y - 2 * kernel_radius - max_displacement + round_off_s1 - 1) / stride_1 + 1 - + round_off; + + // Same here: + int xmax = (x - max_displacement + round_off_s1) / stride_1 - round_off; + int ymax = (y - max_displacement + round_off_s1) / stride_1 - round_off; + + float sum = 0; + + if ((xmax >= 0) && (ymax >= 0) && (xmin <= out_width - 1) && (ymin <= out_height - 1)) { + xmin = max(0, xmin); + xmax = min(out_width - 1, xmax); + + ymin = max(0, ymin); + ymax = min(out_height - 1, ymax); + + for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; p++) { + for (int o = -neighborhood_grid_radius; o <= neighborhood_grid_radius; o++) { + // Get input_b data: + int s2o = stride_2 * o; + int s2p = stride_2 * p; + int idx_input_b = ((item * padded_in_height + (y + s2p)) * padded_in_width + (x + s2o)) * + in_channels + k; + float input_b_tmp = input_b[idx_input_b]; // input_b[x+s2o,y+s2p,k] + + // Index offset for gradient in following loops: + int op = (p + neighborhood_grid_radius) * neighborhood_grid_width + + (o + neighborhood_grid_radius); // index [o,p] + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + // gradient[x,y,o,p] + int idx_gradient = ((item * out_height + y) * out_width + x) * out_channels + op; + sum += gradient[idx_gradient] * input_b_tmp; + } + } + } + } + } + const int sumelems = (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * in_channels; + const int input_a_idx = ((y - pad_size) * in_width + (x - pad_size)) * in_channels + k; + output_a_gradient[input_a_idx + item * in_count_per_sample] = sum / (float)sumelems; + } +} + +__global__ void CorrelateDataBackward1(const int nthreads, + int item, + int out_width, + int out_height, + int out_channels, + int max_displacement, + int neighborhood_grid_radius, + int neighborhood_grid_width, + int kernel_radius, + int stride_1, + int stride_2, + int in_width, + int in_height, + int padded_in_width, + int padded_in_height, + int in_channels, + int in_count_per_sample, + int pad_size, + float *output_b_gradient, + const float *input_a, + const float *gradient) +{ + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int k = index % in_channels; // channels + int x = (index / in_channels) % in_width + pad_size; // w-pos + int y = (index / in_channels / in_width) % in_height + pad_size; // h-pos + + // round_off is a trick to enable integer division with ceil, even for + // negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = stride_1 * round_off; + + float sum = 0; + + // Height (y) + for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; p++) { + // Width (x) + for (int o = -neighborhood_grid_radius; o <= neighborhood_grid_radius; o++) { + int s2o = stride_2 * o; + int s2p = stride_2 * p; + + // Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off + // after it, to ensure the formula matches ceil behavior: + int xmin = (x - 2 * kernel_radius - max_displacement - s2o + round_off_s1 - 1) / stride_1 + + 1 - round_off; + int ymin = (y - 2 * kernel_radius - max_displacement - s2p + round_off_s1 - 1) / stride_1 + + 1 - round_off; + + // Caffe, NKHW: ((n * K + k) * H + h) * W + w at point (n, k, h, w) + // TF, NHWK: ((n * H + h) * W + w) * K + k at point (n, h, w, k) + + // Same here: + int xmax = (x - max_displacement - s2o + round_off_s1) / stride_1 - round_off; + int ymax = (y - max_displacement - s2p + round_off_s1) / stride_1 - round_off; + + if ((xmax >= 0) && (ymax >= 0) && (xmin <= out_width - 1) && (ymin <= out_height - 1)) { + xmin = max(0, xmin); + xmax = min(out_width - 1, xmax); + + ymin = max(0, ymin); + ymax = min(out_height - 1, ymax); + + // Get input_a data: + int idx_input_a = ((item * padded_in_height + (y - s2p)) * padded_in_width + (x - s2o)) * + in_channels + k; + float input_a_tmp = input_a[idx_input_a]; + + // Index offset for gradient in following loops: + int op = (p + neighborhood_grid_radius) * neighborhood_grid_width + + (o + neighborhood_grid_radius); // index [o,p] + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idx_gradient = ((item * out_height + y) * out_width + x) * out_channels + op; + sum += gradient[idx_gradient] * input_a_tmp; + } + } + } + } + } + const int sumelems = (kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * in_channels; + const int input_b_idx = ((y - pad_size) * in_width + (x - pad_size)) * in_channels + k; + output_b_gradient[input_b_idx + item * in_count_per_sample] = sum / (float)sumelems; + } +} + +void CorrelationGradA(const GPUDevice& device, + const int batch_size, + const int out_width, + const int out_height, + const int out_channels, + const int max_displacement, + const int neighborhood_grid_radius, + const int neighborhood_grid_width, + const int kernel_radius, + const int stride_1, + const int stride_2, + const int in_width, + const int in_height, + const int padded_in_width, + const int padded_in_height, + const int in_channels, + const int in_count_per_sample, // h * w * ch + const int pad, + const float *input_b, + const float *gradient, + float *output_a_gradient) { + CudaLaunchConfig config = GetCudaLaunchConfig(in_count_per_sample, device); + + for (int n = 0; n < batch_size; n++) { + CorrelateDataBackward0 << < config.block_count, config.thread_per_block, 0, + device.stream() >> > ( + in_count_per_sample, + n, out_width, out_height, out_channels, + max_displacement, neighborhood_grid_radius, neighborhood_grid_width, kernel_radius, + stride_1, stride_2, + in_width, in_height, padded_in_width, padded_in_height, in_channels, in_count_per_sample, pad, + output_a_gradient, input_b, gradient); + } +} + +void CorrelationGradB(const GPUDevice& device, + const int batch_size, + const int out_width, + const int out_height, + const int out_channels, + const int max_displacement, + const int neighborhood_grid_radius, + const int neighborhood_grid_width, + const int kernel_radius, + const int stride_1, + const int stride_2, + const int in_width, + const int in_height, + const int padded_in_width, + const int padded_in_height, + const int in_channels, + const int in_count_per_sample, + const int pad, + const float *input_a, + const float *gradient, + float *output_b_gradient) { + CudaLaunchConfig config = GetCudaLaunchConfig(in_count_per_sample, device); + + for (int n = 0; n < batch_size; n++) { + CorrelateDataBackward1 << < config.block_count, config.thread_per_block, 0, + device.stream() >> > ( + in_count_per_sample, + n, out_width, out_height, out_channels, + max_displacement, neighborhood_grid_radius, neighborhood_grid_width, kernel_radius, + stride_1, stride_2, + in_width, in_height, padded_in_width, padded_in_height, in_channels, in_count_per_sample, pad, + output_b_gradient, input_a, gradient); + } +} +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/Codes/flownet2/src/ops/correlation/correlation_kernel.cc b/Codes/flownet2/src/ops/correlation/correlation_kernel.cc new file mode 100644 index 0000000..f8a5193 --- /dev/null +++ b/Codes/flownet2/src/ops/correlation/correlation_kernel.cc @@ -0,0 +1,137 @@ +#define EIGEN_USE_THREADS + +#include <utility> + +#include "correlation_kernel.h" +#include "pad.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +template<typename Device> +class CorrelationKernel : public OpKernel { + public: + explicit CorrelationKernel(OpKernelConstruction *ctx) : OpKernel(ctx) { + // Get the attributes + OP_REQUIRES_OK(ctx, ctx->GetAttr("kernel_size", &kernel_size)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_displacement", &max_displacement)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("stride_1", &stride_1)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("stride_2", &stride_2)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("pad", &pad)); + + OP_REQUIRES(ctx, kernel_size % 2 != 0, errors::InvalidArgument("kernel_size must be odd")); + } + + void Compute(OpKernelContext *ctx) override { + // Get the input images and transforms and verify their dimensions + const Tensor& input_a_t = ctx->input(0); + const Tensor& input_b_t = ctx->input(1); + + OP_REQUIRES(ctx, input_a_t.dims() == 4, errors::InvalidArgument("input_a must have rank 4")); + OP_REQUIRES(ctx, input_b_t.dims() == 4, errors::InvalidArgument("input_b must have rank 4")); + + // Get dimensions of input (already padded) + int batch_size = input_a_t.dim_size(0); + int input_height = input_a_t.dim_size(1); + int input_width = input_a_t.dim_size(2); + int input_channels = input_a_t.dim_size(3); + int padded_height = input_height + 2 * pad; + int padded_width = input_width + 2 * pad; + + // The size of unreachable border region on each side + int kernel_radius = (kernel_size - 1) / 2; + int border_size = max_displacement + kernel_radius; + + // Calculate the output dimensions + int output_height = ceil((float)(padded_height - border_size * 2) / (float)stride_1); + int output_width = ceil((float)(padded_width - border_size * 2) / (float)stride_1); + + OP_REQUIRES(ctx, output_height >= 1, + errors::InvalidArgument("Neighborhood and kernel don't fit in input height.")); + OP_REQUIRES(ctx, output_width >= 1, + errors::InvalidArgument("Neighborhood and kernel don't fit in input width.")); + + int neighborhood_grid_radius = max_displacement / stride_2; + int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; + int output_channels = neighborhood_grid_width * neighborhood_grid_width; + + // Allocate the memory for the output + Tensor *output_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 0, + TensorShape({ batch_size, output_height, output_width, output_channels }), + &output_t)); + + // Get the tensors + auto input_a = input_a_t.tensor<float, 4>(); + auto input_b = input_b_t.tensor<float, 4>(); + auto output = output_t->tensor<float, 4>(); + + // Create temporary tensors for padded inputs + Tensor padded_input_a_t, padded_input_b_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum<float>::value, + TensorShape({ batch_size, padded_height, padded_width, input_channels }), + &padded_input_a_t)); + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum<float>::value, + TensorShape({ batch_size, padded_height, padded_width, input_channels }), + &padded_input_b_t)); + auto padded_input_a = padded_input_a_t.tensor<float, 4>(); + auto padded_input_b = padded_input_b_t.tensor<float, 4>(); + + // Pad the inputs + Pad(ctx->eigen_device<Device>(), + input_a.data(), + batch_size, + input_height, + input_width, + input_channels, + padded_height, + padded_width, + padded_input_a.data()); + Pad(ctx->eigen_device<Device>(), + input_b.data(), + batch_size, + input_height, + input_width, + input_channels, + padded_height, + padded_width, + padded_input_b.data()); + + // Perform cross correlation + Correlation(ctx->eigen_device<Device>(), + padded_input_a.data(), + padded_input_b.data(), + batch_size, + output_height, + output_width, + output_channels, + output_height * output_width * output_channels, + padded_height, + padded_width, + input_channels, + max_displacement, + neighborhood_grid_radius, + neighborhood_grid_width, + kernel_radius, + kernel_size, + stride_1, + stride_2, + output.data()); + } + + private: + int kernel_size; + int max_displacement; + int stride_1; + int stride_2; + int pad; +}; + +REGISTER_KERNEL_BUILDER(Name("Correlation") + .Device(DEVICE_GPU), + CorrelationKernel<GPUDevice>) +} // end namespace tensorflow diff --git a/Codes/flownet2/src/ops/correlation/correlation_kernel.cu.cc b/Codes/flownet2/src/ops/correlation/correlation_kernel.cu.cc new file mode 100644 index 0000000..c63e489 --- /dev/null +++ b/Codes/flownet2/src/ops/correlation/correlation_kernel.cu.cc @@ -0,0 +1,153 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#define WARPS_PER_BLOCK 1 +#define THREADS_PER_WARP 32 + +#include <stdio.h> +#include <iostream> + +#include "correlation_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; + +__global__ void CorrelateData(int batch_size, + int out_width, + int out_height, + int out_channels, + int out_count, + int max_displacement, + int neighborhood_grid_radius, + int neighborhood_grid_width, + int kernel_radius, + int kernel_size, + int stride_1, + int stride_2, + int in_width_padded, + int in_height_padded, + int in_channels, + const float *input_a, + const float *input_b, + float *output) { + extern __shared__ char patch_data_char[]; + + float *patch_data = (float *)patch_data_char; + + // First (upper left) position of kernel upper-left corner in current center + // position of neighborhood in image 1 + int x1 = blockIdx.x * stride_1 + max_displacement; + int y1 = blockIdx.y * stride_1 + max_displacement; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + // HEIGHT + for (int j = 0; j < kernel_size; j++) { + // WIDTH + for (int i = 0; i < kernel_size; i++) { + int ji_off = ((j * kernel_size) + i) * in_channels; + + // CHANNELS + for (int ch = ch_off; ch < in_channels; ch += (WARPS_PER_BLOCK * THREADS_PER_WARP)) { + int idx1 = ((item * in_height_padded + y1 + j) * in_width_padded + x1 + i) * + in_channels + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = input_a[idx1]; + } + } + } + + __syncthreads(); + + __shared__ float sum[WARPS_PER_BLOCK * THREADS_PER_WARP]; + + // Compute correlation + for (int out_channel = 0; out_channel < out_channels; out_channel++) { + sum[ch_off] = 0; + + int s2o = (out_channel % neighborhood_grid_width - neighborhood_grid_radius) * stride_2; + int s2p = (out_channel / neighborhood_grid_width - neighborhood_grid_radius) * stride_2; + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + // HEIGHT + for (int j = 0; j < kernel_size; j++) { + // WIDTH + for (int i = 0; i < kernel_size; i++) { + int ji_off = ((j * kernel_size) + i) * in_channels; + + // CHANNELS + for (int ch = ch_off; ch < in_channels; ch += (WARPS_PER_BLOCK * THREADS_PER_WARP)) { + int idxPatchData = ji_off + ch; + int idx2 = ((item * in_height_padded + y2 + j) * in_width_padded + x2 + i) * + in_channels + ch; + + sum[ch_off] += patch_data[idxPatchData] * input_b[idx2]; + } + } + } + + __syncthreads(); + + if (ch_off == 0) { + float total_sum = 0; + + for (int idx = 0; idx < WARPS_PER_BLOCK * THREADS_PER_WARP; idx++) { + total_sum += sum[idx]; + } + const int sumelems = kernel_size * kernel_size * in_channels; + const int index = (blockIdx.y * out_width + blockIdx.x) * out_channels + out_channel; + + /* from Caffe: const int index = ((out_channel * out_height + + blockIdx.y) * out_width) + blockIdx.x; */ + output[index + item * out_count] = total_sum / (float)sumelems; + + // Caffe, NKHW: ((n * K + k) * H + h) * W + w at point (n, k, h, w) + // TF, NHWK: ((n * H + h) * W + w) * K + k at point (n, h, w, k) + // n = 0 + // caffe: ((k * H + h) * W + w) + n * K * H * W + // tf: (h * W + w) * K + k + n * H * W * K + } + } +} + +void Correlation(const GPUDevice& device, + const float *input_a, + const float *input_b, + const int batch_size, + const int out_height, + const int out_width, + const int out_channels, + const int out_count, + const int in_height_padded, + const int in_width_padded, + const int in_channels, + int max_displacement, + int neighborhood_grid_radius, + int neighborhood_grid_width, + int kernel_radius, + int kernel_size, + int stride_1, + int stride_2, + float *output) { + dim3 totalBlocksCorr(out_width, out_height, batch_size); + dim3 threadsPerBlock(THREADS_PER_WARP *WARPS_PER_BLOCK); + const int shared_memory_per_block = (kernel_size * kernel_size) * in_channels; + + CorrelateData << < totalBlocksCorr, threadsPerBlock, shared_memory_per_block * sizeof(float), + device.stream() >> > ( + batch_size, out_width, out_height, out_channels, out_count, + max_displacement, neighborhood_grid_radius, neighborhood_grid_width, kernel_radius, + kernel_size, stride_1, stride_2, in_width_padded, in_height_padded, in_channels, + input_a, input_b, output); +} +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/Codes/flownet2/src/ops/correlation/correlation_kernel.h b/Codes/flownet2/src/ops/correlation/correlation_kernel.h new file mode 100644 index 0000000..a1dfb62 --- /dev/null +++ b/Codes/flownet2/src/ops/correlation/correlation_kernel.h @@ -0,0 +1,77 @@ +#ifndef FLOWNET_CORRELATION_H_ +#define FLOWNET_CORRELATION_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; + +void Correlation(const GPUDevice& device, + const float *input_a, + const float *input_b, + const int batch_size, + const int out_height, + const int out_width, + const int out_channels, + const int out_count, + const int in_height_padded, + const int in_width_padded, + const int in_channels, + int max_displacement, + int neighborhood_grid_radius, + int neighborhood_grid_width, + int kernel_radius, + int kernel_size, + int stride_1, + int stride_2, + float *output); + + +void CorrelationGradA(const GPUDevice& device, + const int batch_size, + const int out_width, + const int out_height, + const int out_channels, + const int max_displacement, + const int neighborhood_grid_radius, + const int neighborhood_grid_width, + const int kernel_radius, + const int stride_1, + const int stride_2, + const int in_width, + const int in_height, + const int padded_in_width, + const int padded_in_height, + const int in_channels, + const int in_count_per_sample, + const int pad, + const float *input_b, + const float *gradient, + float *output_a_gradient); + +void CorrelationGradB(const GPUDevice& device, + const int batch_size, + const int out_width, + const int out_height, + const int out_channels, + const int max_displacement, + const int neighborhood_grid_radius, + const int neighborhood_grid_width, + const int kernel_radius, + const int stride_1, + const int stride_2, + const int in_width, + const int in_height, + const int padded_in_width, + const int padded_in_height, + const int in_channels, + const int in_count_per_sample, + const int pad, + const float *input_a, + const float *gradient, + float *output_b_gradient); +} // end namespace tensorflow + +#endif // FLOWNET_CORRELATION_H_ diff --git a/Codes/flownet2/src/ops/correlation/correlation_op.cc b/Codes/flownet2/src/ops/correlation/correlation_op.cc new file mode 100644 index 0000000..4f420f0 --- /dev/null +++ b/Codes/flownet2/src/ops/correlation/correlation_op.cc @@ -0,0 +1,83 @@ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +Status SetOutput(InferenceContext *c) { + ShapeHandle input_a, input_b, input; + + // Get shapes of both inputs and verify they are rank 4 + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_a)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &input_b)); + + // Verify inputs are same dimensions + TF_RETURN_IF_ERROR(c->Merge(input_a, input_b, &input)); + + // Get the attributes + int kernel_size, max_displacement, stride_1, stride_2, pad; + TF_RETURN_IF_ERROR(c->GetAttr("kernel_size", &kernel_size)); + TF_RETURN_IF_ERROR(c->GetAttr("max_displacement", &max_displacement)); + TF_RETURN_IF_ERROR(c->GetAttr("stride_1", &stride_1)); + TF_RETURN_IF_ERROR(c->GetAttr("stride_2", &stride_2)); + TF_RETURN_IF_ERROR(c->GetAttr("pad", &pad)); + + // Get dimensions of input (already padded) + int64 batch = c->Value(c->Dim(input, 0)); + int64 input_height = c->Value(c->Dim(input, 1)); + int64 input_width = c->Value(c->Dim(input, 2)); + int64 padded_height = input_height + 2 * pad; + int64 padded_width = input_width + 2 * pad; + + // The size of unreachable border region on each side + int kernel_radius = (kernel_size - 1) / 2; + int border_size = max_displacement + kernel_radius; + + // Calculate the output dimensions + int64 output_height = (int64)ceil((float)(padded_height - border_size * 2) / (float)stride_1); + int64 output_width = (int64)ceil((float)(padded_width - border_size * 2) / (float)stride_1); + + // TODO: Verify output size >= 1 + + int neighborhood_grid_radius = max_displacement / stride_2; + int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; + int64 output_channels = neighborhood_grid_width * neighborhood_grid_width; + + // Set output shape + c->set_output(0, c->MakeShape({ batch, output_height, output_width, output_channels })); + return Status::OK(); +} + +REGISTER_OP("Correlation") +.Input("input_a: float32") +.Input("input_b: float32") +.Attr("kernel_size: int") +.Attr("max_displacement: int") +.Attr("stride_1: int") +.Attr("stride_2: int") +.Attr("pad: int") +.Output("output: float32") +.SetShapeFn(SetOutput); + +REGISTER_OP("CorrelationGrad") +.Input("gradients: float32") +.Input("input_a: float32") +.Input("input_b: float32") +.Attr("kernel_size: int") +.Attr("max_displacement: int") +.Attr("stride_1: int") +.Attr("stride_2: int") +.Attr("pad: int") +.Output("backprops_a: float32") +.Output("backprops_b: float32") +.SetShapeFn([](InferenceContext *c) { + // Output gradients should be the same dimensions as the inputs + ShapeHandle out; + TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->input(2), &out)); + c->set_output(0, out); + c->set_output(1, out); + return Status::OK(); + }); +} // namespace tensorflow diff --git a/Codes/flownet2/src/ops/correlation/pad.cu.cc b/Codes/flownet2/src/ops/correlation/pad.cu.cc new file mode 100644 index 0000000..0b6c93d --- /dev/null +++ b/Codes/flownet2/src/ops/correlation/pad.cu.cc @@ -0,0 +1,76 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include <stdio.h> +#include <iostream> + +#include "pad.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; + +__global__ void PadData( + const float *in, + int in_widthheight, + int in_width, + int in_height, + int out_width, + int out_height, + int channels, + int padding, + float *out) { + int xy = blockIdx.x * blockDim.x + threadIdx.x; + + int x = xy % in_width; + int y = xy / in_width; + int ch = blockIdx.y; + int n = blockIdx.z; + + if (xy >= in_widthheight) { + out[((n * out_height + y) * out_width + x) * channels + ch] = 0.0; + return; + } + + float value = in[((n * in_height + y) * in_width + x) * channels + ch]; + + __syncthreads(); + + int xpad = x + padding; + int ypad = y + padding; + + out[((n * out_height + ypad) * out_width + xpad) * channels + ch] = value; +} + +void Pad(const GPUDevice& device, + const float *input, + int batch_size, + int input_height, + int input_width, + int input_channels, + int output_height, + int output_width, + float *output) { + int in_widthheight = input_width * input_height; + int threads_per_block = 16; + dim3 totalBlocks((in_widthheight - 1) / threads_per_block + 1, input_channels, batch_size); + + cudaMemset(output, 0, batch_size * output_height * output_width * input_channels * sizeof(float)); + + int padding = (output_height - input_height) / 2; + + // LAUNCH KERNEL + PadData << < totalBlocks, threads_per_block, 0, device.stream() >> > ( + input, + in_widthheight, + input_width, + input_height, + output_width, + output_height, + input_channels, + padding, + output); +} +} +#endif // if GOOGLE_CUDA diff --git a/Codes/flownet2/src/ops/correlation/pad.h b/Codes/flownet2/src/ops/correlation/pad.h new file mode 100644 index 0000000..afb4df0 --- /dev/null +++ b/Codes/flownet2/src/ops/correlation/pad.h @@ -0,0 +1,20 @@ +#ifndef FLOWNET_PAD_H_ +#define FLOWNET_PAD_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; + +void Pad(const GPUDevice& device, + const float *input, + int batch_size, + int input_height, + int input_width, + int input_channels, + int output_height, + int output_width, + float *output); +} // end namespace tensorflow + +#endif // ifndef FLOWNET_PAD_H_ |
