Diffstat (limited to 'Codes/flownet2/src/ops/flow_warp')
-rw-r--r--  Codes/flownet2/src/ops/flow_warp/flow_warp.cc           48
-rw-r--r--  Codes/flownet2/src/ops/flow_warp/flow_warp.cu.cc       130
-rw-r--r--  Codes/flownet2/src/ops/flow_warp/flow_warp.h            28
-rw-r--r--  Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cc      57
-rw-r--r--  Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cu.cc  126
-rw-r--r--  Codes/flownet2/src/ops/flow_warp/flow_warp_op.cc        23
6 files changed, 412 insertions(+), 0 deletions(-)
diff --git a/Codes/flownet2/src/ops/flow_warp/flow_warp.cc b/Codes/flownet2/src/ops/flow_warp/flow_warp.cc
new file mode 100644
index 0000000..b5d9602
--- /dev/null
+++ b/Codes/flownet2/src/ops/flow_warp/flow_warp.cc
@@ -0,0 +1,48 @@
+#define EIGEN_USE_THREADS
+
+#include "flow_warp.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+template<typename Device>
+class FlowWarpKernel : public OpKernel {
+ public:
+ explicit FlowWarpKernel(OpKernelConstruction *ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext *ctx) override {
+ // Get the input image and flow and verify dimensions
+ const Tensor& input_t = ctx->input(0);
+ const Tensor& flow_t = ctx->input(1);
+
+ OP_REQUIRES(ctx, input_t.dims() == 4,
+ errors::InvalidArgument("Input image must have rank 4"));
+ OP_REQUIRES(ctx, flow_t.dims() == 4,
+ errors::InvalidArgument("Input flow must have rank 4"));
+ OP_REQUIRES(ctx,
+             input_t.dim_size(0) == flow_t.dim_size(0) &&
+             input_t.dim_size(1) == flow_t.dim_size(1) &&
+             input_t.dim_size(2) == flow_t.dim_size(2),
+             errors::InvalidArgument("Input image and flow must have same N x H x W dimensions"));
+
+ // Allocate the memory for the output
+ Tensor *output_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input_t.shape(), &output_t));
+
+ // Perform the flow warping on the GPU
+ auto input = input_t.tensor<float, 4>();
+ auto flow = flow_t.tensor<float, 4>();
+ auto output = output_t->tensor<float, 4>();
+
+ FlowWarp(ctx->eigen_gpu_device(), input, flow, output);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("FlowWarp")
+ .Device(DEVICE_GPU),
+ FlowWarpKernel<GPUDevice>)
+} // end namespace tensorflow
diff --git a/Codes/flownet2/src/ops/flow_warp/flow_warp.cu.cc b/Codes/flownet2/src/ops/flow_warp/flow_warp.cu.cc
new file mode 100644
index 0000000..2007151
--- /dev/null
+++ b/Codes/flownet2/src/ops/flow_warp/flow_warp.cu.cc
@@ -0,0 +1,130 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+#include <iostream>
+
+#include "flow_warp.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+#define RA_TILE 32
+#define RA_ROWS 8
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+__global__ void FlowWarpKernel(
+ const float *image,
+ const float *flow,
+ float *warped,
+ const int batch_size,
+ const int channels,
+ const int cblocks,
+ const int width,
+ const int wblocks,
+ const int height,
+ const int width_height) {
+ int y = blockIdx.y;
+ int n = blockIdx.z;
+
+ __shared__ float x2_buf[FW_TILE_X], y2_buf[FW_TILE_X];
+ __shared__ float buffer[FW_TILE_C][FW_TILE_X + 1];
+
+ int x;
+ int c;
+
+ x = blockIdx.x * FW_TILE_X + threadIdx.x;
+
+ if ((threadIdx.y == 0) && (x < width)) {
+ const int idx = ((n * height + y) * width + x) * 2;
+ x2_buf[threadIdx.x] = float(x) + flow[idx];
+ y2_buf[threadIdx.x] = float(y) + flow[idx + 1];
+ }
+
+ __syncthreads();
+
+ float x2 = x2_buf[threadIdx.y];
+ float y2 = y2_buf[threadIdx.y];
+
+ int ix2_L = int(x2);
+ int iy2_T = int(y2);
+ int ix2_R = min(ix2_L + 1, width - 1);
+ int iy2_B = min(iy2_T + 1, height - 1);
+
+ int off_TL = ((n * height + iy2_T) * width + ix2_L) * channels;
+ int off_TR = ((n * height + iy2_T) * width + ix2_R) * channels;
+ int off_BL = ((n * height + iy2_B) * width + ix2_L) * channels;
+ int off_BR = ((n * height + iy2_B) * width + ix2_R) * channels;
+
+ float alpha = x2 - ix2_L;
+ float beta = y2 - iy2_T;
+ float coeffTL = (1 - alpha) * (1 - beta);
+ float coeffTR = alpha * (1 - beta);
+ float coeffBL = (1 - alpha) * beta;
+ float coeffBR = alpha * beta;
+
+ for (int cb = 0; cb < cblocks; cb++) {
+ __syncthreads();
+
+ buffer[threadIdx.y][threadIdx.x] = 0.0;
+
+ __syncthreads();
+
+ c = cb * FW_TILE_C + threadIdx.x;
+
+ if ((x2 >= 0) && (y2 >= 0) && (x2 < width) && (y2 < height) && (c < channels)) {
+ buffer[threadIdx.y][threadIdx.x] = // buffer [x][c]
+ coeffTL * image[off_TL + c] +
+ coeffTR * image[off_TR + c] +
+ coeffBL * image[off_BL + c] +
+ coeffBR * image[off_BR + c];
+ }
+
+ __syncthreads();
+
+ c = cb * FW_TILE_C + threadIdx.y;
+ x = blockIdx.x * FW_TILE_X + threadIdx.x;
+
+ if ((c < channels) && (x < width)) {
+ warped[((n * height + y) * width + x) * channels + c] = buffer[threadIdx.x][threadIdx.y];
+ }
+ }
+}
+
+void FlowWarp(const GPUDevice& device,
+ typename TTypes<float, 4>::ConstTensor input,
+ typename TTypes<float, 4>::ConstTensor flow,
+ typename TTypes<float, 4>::Tensor output) {
+ const int batch_size = input.dimension(0);
+ const int height = input.dimension(1);
+ const int width = input.dimension(2);
+ const int channels = input.dimension(3);
+
+ const int width_height = width * height;
+ int wblocks = ((width - 1) / FW_TILE_X + 1);
+ int cblocks = ((channels - 1) / FW_TILE_C + 1);
+ dim3 warpThreads(FW_TILE_X, FW_TILE_C);
+ dim3 warpBlocks(wblocks, height, batch_size);
+
+ cudaMemset(output.data(), 0, batch_size * height * width * channels * sizeof(float));
+
+ FlowWarpKernel<<<warpBlocks, warpThreads, 0, device.stream()>>>(
+ input.data(),
+ flow.data(),
+ output.data(),
+ batch_size,
+ channels,
+ cblocks,
+ width,
+ wblocks,
+ height,
+ width_height);
+}
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
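The kernel above implements bilinear sampling at the flow-displaced coordinates, staged through shared memory so that flow reads, image reads, and warped writes stay coalesced across the x and channel tiles. For reference, the value written for a single output element can be expressed as a plain CPU function. The sketch below is not part of the diff: warp_pixel_ref is a hypothetical helper that assumes NHWC image layout, an N x H x W x 2 flow tensor, and zero output for samples that land outside the image, matching the kernel's in-bounds guard.

#include <algorithm>

// Sketch only: the value FlowWarpKernel produces for warped[n, y, x, c].
static float warp_pixel_ref(const float *image, const float *flow,
                            int n, int y, int x, int c,
                            int height, int width, int channels) {
  const int flow_idx = ((n * height + y) * width + x) * 2;
  const float x2 = float(x) + flow[flow_idx];      // displaced x coordinate
  const float y2 = float(y) + flow[flow_idx + 1];  // displaced y coordinate
  if (x2 < 0.f || y2 < 0.f || x2 >= width || y2 >= height) return 0.f;

  const int xL = int(x2), xR = std::min(xL + 1, width - 1);
  const int yT = int(y2), yB = std::min(yT + 1, height - 1);
  const float alpha = x2 - xL;  // horizontal interpolation weight
  const float beta  = y2 - yT;  // vertical interpolation weight

  auto at = [&](int yy, int xx) {
    return image[((n * height + yy) * width + xx) * channels + c];
  };
  return (1 - alpha) * (1 - beta) * at(yT, xL) +
         alpha * (1 - beta) * at(yT, xR) +
         (1 - alpha) * beta * at(yB, xL) +
         alpha * beta * at(yB, xR);
}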
diff --git a/Codes/flownet2/src/ops/flow_warp/flow_warp.h b/Codes/flownet2/src/ops/flow_warp/flow_warp.h
new file mode 100644
index 0000000..2780316
--- /dev/null
+++ b/Codes/flownet2/src/ops/flow_warp/flow_warp.h
@@ -0,0 +1,28 @@
+#ifndef FLOWNET_FLOWWARP_H_
+#define FLOWNET_FLOWWARP_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+
+#define FW_THREADS 32
+#define FW_TILE_X FW_THREADS
+#define FW_TILE_C FW_THREADS
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+void FlowWarp(const GPUDevice& device,
+ typename TTypes<float, 4>::ConstTensor input,
+ typename TTypes<float, 4>::ConstTensor flow,
+ typename TTypes<float, 4>::Tensor output);
+
+void FlowWarpGrad(const GPUDevice& device,
+ typename TTypes<float, 4>::ConstTensor image,
+ typename TTypes<float, 4>::ConstTensor flow,
+ typename TTypes<float, 4>::ConstTensor gradient,
+ typename TTypes<float, 4>::Tensor image_grad,
+ typename TTypes<float, 4>::Tensor flow_grad);
+} // end namespace tensorflow
+
+#endif // FLOWNET_FLOWWARP_H_
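The tile constants above fix the launch geometry used by both .cu.cc files. As a worked example under an assumed input size (2 x 384 x 512 x 3, not taken from the diff), the forward pass launches a (16, 384, 2) grid of 32 x 32 thread blocks, i.e. 1024 threads per block, the per-block limit on current GPUs, while the channel dimension is covered by the cblocks loop inside the kernel.

// Sketch: launch-geometry arithmetic for an assumed 2 x 384 x 512 x 3 input.
constexpr int kTileX = 32, kTileC = 32;                  // FW_TILE_X, FW_TILE_C
constexpr int kWidth = 512, kChannels = 3;
constexpr int kWBlocks = (kWidth - 1) / kTileX + 1;      // 16 blocks across x
constexpr int kCBlocks = (kChannels - 1) / kTileC + 1;   // 1 channel block
static_assert(kWBlocks == 16 && kCBlocks == 1, "launch geometry example");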
diff --git a/Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cc b/Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cc
new file mode 100644
index 0000000..9f3e7ea
--- /dev/null
+++ b/Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cc
@@ -0,0 +1,57 @@
+#define EIGEN_USE_THREADS
+
+#include "flow_warp.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+template<typename Device>
+class FlowWarpGradKernel : public OpKernel {
+ public:
+ explicit FlowWarpGradKernel(OpKernelConstruction *ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext *ctx) override {
+ // Get the input image, flow, and incoming gradient, and verify dimensions
+ const Tensor& image_t = ctx->input(0);
+ const Tensor& flow_t = ctx->input(1);
+ const Tensor& grad_t = ctx->input(2);
+
+ OP_REQUIRES(ctx, image_t.dims() == 4,
+ errors::InvalidArgument("Input image must have rank 4"));
+ OP_REQUIRES(ctx, flow_t.dims() == 4,
+ errors::InvalidArgument("Input flow must have rank 4"));
+ OP_REQUIRES(ctx,
+             image_t.dim_size(0) == flow_t.dim_size(0) &&
+             image_t.dim_size(1) == flow_t.dim_size(1) &&
+             image_t.dim_size(2) == flow_t.dim_size(2),
+             errors::InvalidArgument("Input image and flow must have same N x H x W dimensions"));
+
+ // Allocate the memory for the output
+ Tensor *image_grad_t;
+ Tensor *flow_grad_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, image_t.shape(), &image_grad_t));
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(1, flow_t.shape(), &flow_grad_t));
+
+ auto image = image_t.tensor<float, 4>();
+ auto flow = flow_t.tensor<float, 4>();
+ auto gradient = grad_t.tensor<float, 4>();
+ auto image_grad = image_grad_t->tensor<float, 4>();
+ auto flow_grad = flow_grad_t->tensor<float, 4>();
+
+ FlowWarpGrad(ctx->eigen_gpu_device(),
+ image,
+ flow,
+ gradient,
+ image_grad,
+ flow_grad);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("FlowWarpGrad")
+ .Device(DEVICE_GPU),
+ FlowWarpGradKernel<GPUDevice>)
+} // end namespace tensorflow
diff --git a/Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cu.cc b/Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cu.cc
new file mode 100644
index 0000000..25248c8
--- /dev/null
+++ b/Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cu.cc
@@ -0,0 +1,126 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "flow_warp.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+__global__ void FlowWarpGradKernel(
+ const float *image,
+ float *image_grad,
+ const float *flow,
+ float *flow_grad,
+ const float *gradient,
+ int batch_size,
+ int channels,
+ int cblocks,
+ int width,
+ int wblocks,
+ int height,
+ int widthheight) {
+ int x = blockIdx.x * FW_TILE_X + threadIdx.x;
+
+ if (x >= width) return;
+
+ int y = blockIdx.y;
+ int n = blockIdx.z;
+
+ const int flow_idx = ((n * height + y) * width + x) * 2;
+ float x2 = float(x) + flow[flow_idx];
+ float y2 = float(y) + flow[flow_idx + 1];
+
+ if ((x2 >= 0.f) && (y2 >= 0.f) && (x2 < width) && (y2 < height)) {
+ int ix2_L = int(x2);
+ int iy2_T = int(y2);
+ int ix2_R = min(ix2_L + 1, width - 1);
+ int iy2_B = min(iy2_T + 1, height - 1);
+
+ float alpha = x2 - ix2_L;
+ float beta = y2 - iy2_T;
+
+ for (int c = 0; c < channels; c++) {
+ float warped_diff_value = gradient[((n * height + y) * width + x) * channels + c];
+ atomicAdd(&image_grad[((n * height + iy2_T) * width + ix2_L) * channels + c],
+ warped_diff_value * (1 - alpha) * (1 - beta));
+ atomicAdd(&image_grad[((n * height + iy2_T) * width + ix2_R) * channels + c],
+ warped_diff_value * alpha * (1 - beta));
+ atomicAdd(&image_grad[((n * height + iy2_B) * width + ix2_L) * channels + c],
+ warped_diff_value * (1 - alpha) * beta);
+ atomicAdd(&image_grad[((n * height + iy2_B) * width + ix2_R) * channels + c],
+ warped_diff_value * alpha * beta);
+ }
+
+ float gamma = iy2_B - y2;
+ float bot_diff = 0;
+
+ for (int c = 0; c < channels; c++) {
+ int ch_off = (n * channels + c) * height;
+ float temp = 0;
+ temp += gamma *
+ (image[((n * height + iy2_T) * width + ix2_R) * channels + c] -
+ image[((n * height + iy2_T) * width + ix2_L) * channels + c]);
+ temp += (1 - gamma) *
+ (image[((n * height + iy2_B) * width + ix2_R) * channels + c] -
+ image[((n * height + iy2_B) * width + ix2_L) * channels + c]);
+
+ bot_diff += gradient[((n * height + y) * width + x) * channels + c] * temp;
+ }
+ flow_grad[((n * height + y) * width + x) * 2] = bot_diff;
+
+ gamma = ix2_R - x2;
+ bot_diff = 0;
+
+ for (int c = 0; c < channels; c++) {
+ float temp = 0;
+ temp += gamma *
+ (image[((n * height + iy2_B) * width + ix2_L) * channels + c] -
+ image[((n * height + iy2_T) * width + ix2_L) * channels + c]);
+ temp += (1 - gamma) *
+ (image[((n * height + iy2_B) * width + ix2_R) * channels + c] -
+ image[((n * height + iy2_T) * width + ix2_R) * channels + c]);
+
+ bot_diff += gradient[((n * height + y) * width + x) * channels + c] * temp;
+ }
+ flow_grad[((n * height + y) * width + x) * 2 + 1] = bot_diff;
+ }
+}
+
+void FlowWarpGrad(const GPUDevice& device,
+ typename TTypes<float, 4>::ConstTensor image,
+ typename TTypes<float, 4>::ConstTensor flow,
+ typename TTypes<float, 4>::ConstTensor gradient,
+ typename TTypes<float, 4>::Tensor image_grad,
+ typename TTypes<float, 4>::Tensor flow_grad) {
+ const int batch_size = image.dimension(0);
+ const int height = image.dimension(1);
+ const int width = image.dimension(2);
+ const int channels = image.dimension(3);
+ const int width_height = width * height;
+
+ int wblocks = ((width - 1) / FW_TILE_X + 1);
+ int cblocks = ((channels - 1) / FW_TILE_C + 1);
+ dim3 warpThreads(FW_TILE_X, 1);
+ dim3 warpBlocks(wblocks, height, batch_size);
+
+ cudaMemset(image_grad.data(), 0, batch_size * height * width * channels * sizeof(float));
+ cudaMemset(flow_grad.data(), 0, batch_size * height * width * 2 * sizeof(float));
+
+ FlowWarpGradKernel<<<warpBlocks, warpThreads, 0, device.stream()>>>(
+ image.data(),
+ image_grad.data(),
+ flow.data(),
+ flow_grad.data(),
+ gradient.data(),
+ batch_size,
+ channels,
+ cblocks,
+ width,
+ wblocks,
+ height,
+ width_height);
+}
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
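The first loop over channels scatters the four bilinear weights back into image_grad with atomicAdd; the remaining two loops accumulate the analytic derivative of bilinear sampling with respect to the sampling coordinates (x2, y2), i.e. with respect to the two flow components. A scalar sketch of that derivative follows; warp_coord_grad_ref is a hypothetical helper (one pixel, one channel, interior case where the right and bottom neighbours are not clamped), handy for checking the kernel against finite differences.

// Sketch only: derivative of the bilinearly sampled value w.r.t. (x2, y2).
// TL, TR, BL, BR are the four neighbouring image values; alpha and beta are
// the fractional offsets x2 - ix2_L and y2 - iy2_T used above.
struct WarpCoordGrad { float d_dx2; float d_dy2; };

static WarpCoordGrad warp_coord_grad_ref(float TL, float TR, float BL, float BR,
                                         float alpha, float beta) {
  // warped = (1-a)(1-b)*TL + a(1-b)*TR + (1-a)*b*BL + a*b*BR
  WarpCoordGrad g;
  g.d_dx2 = (1 - beta) * (TR - TL) + beta * (BR - BL);    // gamma = iy2_B - y2 above
  g.d_dy2 = (1 - alpha) * (BL - TL) + alpha * (BR - TR);  // gamma = ix2_R - x2 above
  return g;
}
// Each flow_grad entry written above is the sum over channels of the incoming
// gradient times one of these two derivatives.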
diff --git a/Codes/flownet2/src/ops/flow_warp/flow_warp_op.cc b/Codes/flownet2/src/ops/flow_warp/flow_warp_op.cc
new file mode 100644
index 0000000..aef9c74
--- /dev/null
+++ b/Codes/flownet2/src/ops/flow_warp/flow_warp_op.cc
@@ -0,0 +1,23 @@
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+REGISTER_OP("FlowWarp")
+.Input("image: float32")
+.Input("flow: float32")
+.Output("output: float32")
+.SetShapeFn(::tensorflow::shape_inference::UnchangedShape);
+
+REGISTER_OP("FlowWarpGrad")
+.Input("image: float32")
+.Input("flow: float32")
+.Input("gradient: float32")
+.Output("image_grad: float32")
+.Output("flow_grad: float32")
+.SetShapeFn([](shape_inference::InferenceContext *c) {
+ c->set_output(0, c->input(0));
+ c->set_output(1, c->input(1));
+ return Status::OK();
+ });
+} // namespace tensorflow
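Both registrations defer all shape checking to the kernels; UnchangedShape simply propagates the image shape to the output, and the gradient op mirrors its two inputs. As a minimal sketch, and only an assumption rather than part of this diff (the op name FlowWarpChecked is hypothetical), the same rank-4 requirement could also be surfaced at graph-construction time through the shape-inference API:

// Sketch: expressing the rank-4 constraint in the shape function itself,
// using the same includes and namespace as the registrations above.
REGISTER_OP("FlowWarpChecked")
.Input("image: float32")
.Input("flow: float32")
.Output("output: float32")
.SetShapeFn([](shape_inference::InferenceContext *c) {
    shape_inference::ShapeHandle image, flow;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &image));
    TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &flow));
    c->set_output(0, image);
    return Status::OK();
  });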