Diffstat (limited to 'Codes/flownet2/src/ops/flow_warp')
-rw-r--r--  Codes/flownet2/src/ops/flow_warp/flow_warp.cc           48
-rw-r--r--  Codes/flownet2/src/ops/flow_warp/flow_warp.cu.cc       130
-rw-r--r--  Codes/flownet2/src/ops/flow_warp/flow_warp.h            28
-rw-r--r--  Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cc      57
-rw-r--r--  Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cu.cc  126
-rw-r--r--  Codes/flownet2/src/ops/flow_warp/flow_warp_op.cc        23
6 files changed, 412 insertions, 0 deletions
diff --git a/Codes/flownet2/src/ops/flow_warp/flow_warp.cc b/Codes/flownet2/src/ops/flow_warp/flow_warp.cc
new file mode 100644
index 0000000..b5d9602
--- /dev/null
+++ b/Codes/flownet2/src/ops/flow_warp/flow_warp.cc
@@ -0,0 +1,48 @@
+#define EIGEN_USE_THREADS
+
+#include "flow_warp.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+template<typename Device>
+class FlowWarpKernel : public OpKernel {
+ public:
+  explicit FlowWarpKernel(OpKernelConstruction *ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext *ctx) override {
+    // Get the input image and flow and verify dimensions
+    const Tensor& input_t = ctx->input(0);
+    const Tensor& flow_t = ctx->input(1);
+
+    OP_REQUIRES(ctx, input_t.dims() == 4,
+                errors::InvalidArgument("Input image must have rank 4"));
+    OP_REQUIRES(ctx, flow_t.dims() == 4,
+                errors::InvalidArgument("Input flow must have rank 4"));
+    OP_REQUIRES(ctx,
+                input_t.dim_size(0) == flow_t.dim_size(0) && input_t.dim_size(1) == flow_t.dim_size(1) &&
+                input_t.dim_size(2) == flow_t.dim_size(2),
+                errors::InvalidArgument(
+                  "Input image and flow must have same N x H x W dimensions"));
+
+    // Allocate the memory for the output (same N x H x W x C shape as the input)
+    Tensor *output_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input_t.shape(), &output_t));
+
+    // Warp the input image along the flow field on the GPU
+    auto input = input_t.tensor<float, 4>();
+    auto flow = flow_t.tensor<float, 4>();
+    auto output = output_t->tensor<float, 4>();
+
+    FlowWarp(ctx->eigen_gpu_device(), input, flow, output);
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("FlowWarp")
+                        .Device(DEVICE_GPU),
+                        FlowWarpKernel<GPUDevice>)
+} // end namespace tensorflow
diff --git a/Codes/flownet2/src/ops/flow_warp/flow_warp.cu.cc b/Codes/flownet2/src/ops/flow_warp/flow_warp.cu.cc
new file mode 100644
index 0000000..2007151
--- /dev/null
+++ b/Codes/flownet2/src/ops/flow_warp/flow_warp.cu.cc
@@ -0,0 +1,130 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+#include <iostream>
+
+#include "flow_warp.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+#define RA_TILE 32
+#define RA_ROWS 8
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+__global__ void FlowWarpKernel(
+  const float *image,
+  const float *flow,
+  float *warped,
+  const int batch_size,
+  const int channels,
+  const int cblocks,
+  const int width,
+  const int wblocks,
+  const int height,
+  const int width_height) {
+  int y = blockIdx.y;
+  int n = blockIdx.z;
+
+  __shared__ float x2_buf[FW_TILE_X], y2_buf[FW_TILE_X];
+  __shared__ float buffer[FW_TILE_C][FW_TILE_X + 1];
+
+  int x;
+  int c;
+
+  x = blockIdx.x * FW_TILE_X + threadIdx.x;
+
+  if ((threadIdx.y == 0) && (x < width)) { // one thread row loads the tile's target coords
+    const int idx = ((n * height + y) * width + x) * 2;
+    x2_buf[threadIdx.x] = float(x) + flow[idx];
+    y2_buf[threadIdx.x] = float(y) + flow[idx + 1];
+  }
+
+  __syncthreads();
+
+  float x2 = x2_buf[threadIdx.y];
+  float y2 = y2_buf[threadIdx.y];
+
+  int ix2_L = int(x2);
+  int iy2_T = int(y2);
+  int ix2_R = min(ix2_L + 1, width - 1);
+  int iy2_B = min(iy2_T + 1, height - 1);
+
+  int off_TL = ((n * height + iy2_T) * width + ix2_L) * channels;
+  int off_TR = ((n * height + iy2_T) * width + ix2_R) * channels;
+  int off_BL = ((n * height + iy2_B) * width + ix2_L) * channels;
+  int off_BR = ((n * height + iy2_B) * width + ix2_R) * channels;
+
+  float alpha = x2 - ix2_L;
+  float beta = y2 - iy2_T;
+  float coeffTL = (1 - alpha) * (1 - beta);
+  float coeffTR = alpha * (1 - beta);
+  float coeffBL = (1 - alpha) * beta;
+  float coeffBR = alpha * beta;
+
+  for (int cb = 0; cb < cblocks; cb++) {
+    __syncthreads();
+
+    buffer[threadIdx.y][threadIdx.x] = 0.0; // out-of-range samples stay zero
+
+    __syncthreads();
+
+    c = cb * FW_TILE_C + threadIdx.x;
+
+    if ((x2 >= 0) && (y2 >= 0) && (x2 < width) && (y2 < height) && (c < channels)) {
+      buffer[threadIdx.y][threadIdx.x] = // buffer [x][c]
+        coeffTL * image[off_TL + c] +
+        coeffTR * image[off_TR + c] +
+        coeffBL * image[off_BL + c] +
+        coeffBR * image[off_BR + c];
+    }
+
+    __syncthreads();
+
+    c = cb * FW_TILE_C + threadIdx.y; // write back transposed for coalesced stores
+    x = blockIdx.x * FW_TILE_X + threadIdx.x;
+
+    if ((c < channels) && (x < width)) {
+      warped[((n * height + y) * width + x) * channels + c] = buffer[threadIdx.x][threadIdx.y];
+    }
+  }
+}
+
+void FlowWarp(const GPUDevice& device,
+              typename TTypes<float, 4>::ConstTensor input,
+              typename TTypes<float, 4>::ConstTensor flow,
+              typename TTypes<float, 4>::Tensor output) {
+  const int batch_size = input.dimension(0);
+  const int height = input.dimension(1);
+  const int width = input.dimension(2);
+  const int channels = input.dimension(3);
+
+  const int width_height = width * height;
+  int wblocks = ((width - 1) / FW_TILE_X + 1);
+  int cblocks = ((channels - 1) / FW_TILE_C + 1);
+  dim3 warpThreads(FW_TILE_X, FW_TILE_C);
+  dim3 warpBlocks(wblocks, height, batch_size);
+
+  cudaMemset(output.data(), 0, batch_size * height * width * channels * sizeof(float)); // zero the full N x H x W x C output
+
+  FlowWarpKernel<<<warpBlocks, warpThreads, 0, device.stream()>>>(
+    input.data(),
+    flow.data(),
+    output.data(),
+    batch_size,
+    channels,
+    cblocks,
+    width,
+    wblocks,
+    height,
+    width_height);
+}
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
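Note: the forward kernel above computes, for each output pixel, a bilinear sample of the source image at the flow-displaced position (x + flow_x, y + flow_y), producing zero where that position falls outside the frame. Below is a minimal single-threaded CPU sketch of the same sampling, useful as a reference when validating the CUDA path; it is our own illustration, not part of the commit, and the helper name FlowWarpCpu and the NHWC float layout are assumptions mirroring the kernel.

#include <algorithm>

// Hypothetical CPU reference for FlowWarpKernel (NHWC layout, zero outside the frame).
void FlowWarpCpu(const float *image, const float *flow, float *warped,
                 int batch, int height, int width, int channels) {
  for (int n = 0; n < batch; ++n)
    for (int y = 0; y < height; ++y)
      for (int x = 0; x < width; ++x) {
        const int fidx = ((n * height + y) * width + x) * 2;
        const float x2 = x + flow[fidx];      // target column: x + flow_x
        const float y2 = y + flow[fidx + 1];  // target row:    y + flow_y
        for (int c = 0; c < channels; ++c) {
          float val = 0.f;
          if (x2 >= 0.f && y2 >= 0.f && x2 < width && y2 < height) {
            const int xL = int(x2), yT = int(y2);
            const int xR = std::min(xL + 1, width - 1);  // clamp right neighbor
            const int yB = std::min(yT + 1, height - 1); // clamp bottom neighbor
            const float a = x2 - xL, b = y2 - yT;        // bilinear weights
            auto at = [&](int yy, int xx) {
              return image[((n * height + yy) * width + xx) * channels + c];
            };
            val = (1 - a) * (1 - b) * at(yT, xL) + a * (1 - b) * at(yT, xR) +
                  (1 - a) * b * at(yB, xL) + a * b * at(yB, xR);
          }
          warped[((n * height + y) * width + x) * channels + c] = val;
        }
      }
}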
"tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; + +template<typename Device> +class FlowWarpGradKernel : public OpKernel { + public: + explicit FlowWarpGradKernel(OpKernelConstruction *ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext *ctx) override { + // Get the input image and flow and verify dimensions + const Tensor& image_t = ctx->input(0); + const Tensor& flow_t = ctx->input(1); + const Tensor& grad_t = ctx->input(2); + + OP_REQUIRES(ctx, image_t.dims() == 4, + errors::InvalidArgument("Input image must have rank 4")); + OP_REQUIRES(ctx, flow_t.dims() == 4, + errors::InvalidArgument("Input flow must have rank 4")); + OP_REQUIRES(ctx, + image_t.dim_size(0) == flow_t.dim_size(0) && image_t.dim_size( + 1) == flow_t.dim_size(1) && image_t.dim_size(2) == flow_t.dim_size(2), + errors::InvalidArgument( + "Input image and flow must have same N x H x W dimensions")); + + // Allocate the memory for the output + Tensor *image_grad_t; + Tensor *flow_grad_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, image_t.shape(), &image_grad_t)); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, flow_t.shape(), &flow_grad_t)); + + auto image = image_t.tensor<float, 4>(); + auto flow = flow_t.tensor<float, 4>(); + auto gradient = grad_t.tensor<float, 4>(); + auto image_grad = image_grad_t->tensor<float, 4>(); + auto flow_grad = flow_grad_t->tensor<float, 4>(); + + FlowWarpGrad(ctx->eigen_gpu_device(), + image, + flow, + gradient, + image_grad, + flow_grad); + } +}; + +REGISTER_KERNEL_BUILDER(Name("FlowWarpGrad") + .Device(DEVICE_GPU), + FlowWarpGradKernel<GPUDevice>) +} // end namespace tensorflow diff --git a/Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cu.cc b/Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cu.cc new file mode 100644 index 0000000..25248c8 --- /dev/null +++ b/Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cu.cc @@ -0,0 +1,126 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "flow_warp.h" + +namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; + +__global__ void FlowWarpGradKernel( + const float *image, + float *image_grad, + const float *flow, + float *flow_grad, + const float *gradient, + int batch_size, + int channels, + int cblocks, + int width, + int wblocks, + int height, + int widthheight) { + int x = blockIdx.x * FW_TILE_X + threadIdx.x; + + if (x >= width) return; + + int y = blockIdx.y; + int n = blockIdx.z; + + const int flow_idx = ((n * height + y) * width + x) * 2; + float x2 = float(x) + flow[flow_idx]; + float y2 = float(y) + flow[flow_idx + 1]; + + if ((x2 >= 0.f) && (y2 >= 0.f) && (x2 < width) && (y2 < height)) { + int ix2_L = int(x2); + int iy2_T = int(y2); + int ix2_R = min(ix2_L + 1, width - 1); + int iy2_B = min(iy2_T + 1, height - 1); + + float alpha = x2 - ix2_L; + float beta = y2 - iy2_T; + + for (int c = 0; c < channels; c++) { + float warped_diff_value = gradient[((n * height + y) * width + x) * channels + c]; + atomicAdd(&image_grad[((n * height + iy2_T) * width + ix2_L) * channels + c], + warped_diff_value * (1 - alpha) * (1 - beta)); + atomicAdd(&image_grad[((n * height + iy2_T) * width + ix2_R) * channels + c], + warped_diff_value * alpha * (1 - beta)); + atomicAdd(&image_grad[((n * height + iy2_B) * width + ix2_L) * channels + c], + warped_diff_value * (1 - alpha) * beta); + atomicAdd(&image_grad[((n * height + iy2_B) * width + ix2_R) * channels + c], + warped_diff_value * alpha * beta); + } + + float gamma = iy2_B - y2; + float bot_diff = 0; 
diff --git a/Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cu.cc b/Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cu.cc
new file mode 100644
index 0000000..25248c8
--- /dev/null
+++ b/Codes/flownet2/src/ops/flow_warp/flow_warp_grad.cu.cc
@@ -0,0 +1,126 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "flow_warp.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+__global__ void FlowWarpGradKernel(
+  const float *image,
+  float *image_grad,
+  const float *flow,
+  float *flow_grad,
+  const float *gradient,
+  int batch_size,
+  int channels,
+  int cblocks,
+  int width,
+  int wblocks,
+  int height,
+  int widthheight) {
+  int x = blockIdx.x * FW_TILE_X + threadIdx.x;
+
+  if (x >= width) return;
+
+  int y = blockIdx.y;
+  int n = blockIdx.z;
+
+  const int flow_idx = ((n * height + y) * width + x) * 2;
+  float x2 = float(x) + flow[flow_idx];
+  float y2 = float(y) + flow[flow_idx + 1];
+
+  if ((x2 >= 0.f) && (y2 >= 0.f) && (x2 < width) && (y2 < height)) {
+    int ix2_L = int(x2);
+    int iy2_T = int(y2);
+    int ix2_R = min(ix2_L + 1, width - 1);
+    int iy2_B = min(iy2_T + 1, height - 1);
+
+    float alpha = x2 - ix2_L;
+    float beta = y2 - iy2_T;
+
+    // Scatter the output gradient to the four source pixels with the bilinear weights
+    for (int c = 0; c < channels; c++) {
+      float warped_diff_value = gradient[((n * height + y) * width + x) * channels + c];
+      atomicAdd(&image_grad[((n * height + iy2_T) * width + ix2_L) * channels + c],
+                warped_diff_value * (1 - alpha) * (1 - beta));
+      atomicAdd(&image_grad[((n * height + iy2_T) * width + ix2_R) * channels + c],
+                warped_diff_value * alpha * (1 - beta));
+      atomicAdd(&image_grad[((n * height + iy2_B) * width + ix2_L) * channels + c],
+                warped_diff_value * (1 - alpha) * beta);
+      atomicAdd(&image_grad[((n * height + iy2_B) * width + ix2_R) * channels + c],
+                warped_diff_value * alpha * beta);
+    }
+
+    float gamma = iy2_B - y2;
+    float bot_diff = 0;
+
+    for (int c = 0; c < channels; c++) {
+      // Accumulate d(loss)/d(flow_x) across channels
+      float temp = 0;
+      temp += gamma *
+              (image[((n * height + iy2_T) * width + ix2_R) * channels + c] -
+               image[((n * height + iy2_T) * width + ix2_L) * channels + c]);
+      temp += (1 - gamma) *
+              (image[((n * height + iy2_B) * width + ix2_R) * channels + c] -
+               image[((n * height + iy2_B) * width + ix2_L) * channels + c]);
+
+      bot_diff += gradient[((n * height + y) * width + x) * channels + c] * temp;
+    }
+    flow_grad[((n * height + y) * width + x) * 2] = bot_diff;
+
+    gamma = ix2_R - x2;
+    bot_diff = 0;
+
+    for (int c = 0; c < channels; c++) {
+      float temp = 0;
+      temp += gamma *
+              (image[((n * height + iy2_B) * width + ix2_L) * channels + c] -
+               image[((n * height + iy2_T) * width + ix2_L) * channels + c]);
+      temp += (1 - gamma) *
+              (image[((n * height + iy2_B) * width + ix2_R) * channels + c] -
+               image[((n * height + iy2_T) * width + ix2_R) * channels + c]);
+
+      bot_diff += gradient[((n * height + y) * width + x) * channels + c] * temp;
+    }
+    flow_grad[((n * height + y) * width + x) * 2 + 1] = bot_diff;
+  }
+}
+
+void FlowWarpGrad(const GPUDevice& device,
+                  typename TTypes<float, 4>::ConstTensor image,
+                  typename TTypes<float, 4>::ConstTensor flow,
+                  typename TTypes<float, 4>::ConstTensor gradient,
+                  typename TTypes<float, 4>::Tensor image_grad,
+                  typename TTypes<float, 4>::Tensor flow_grad) {
+  const int batch_size = image.dimension(0);
+  const int height = image.dimension(1);
+  const int width = image.dimension(2);
+  const int channels = image.dimension(3);
+  const int width_height = width * height;
+
+  int wblocks = ((width - 1) / FW_TILE_X + 1);
+  int cblocks = ((channels - 1) / FW_TILE_C + 1);
+  dim3 warpThreads(FW_TILE_X, 1);
+  dim3 warpBlocks(wblocks, height, batch_size);
+
+  cudaMemset(image_grad.data(), 0, batch_size * height * width * channels * sizeof(float));
+  cudaMemset(flow_grad.data(), 0, batch_size * height * width * 2 * sizeof(float));
+
+  FlowWarpGradKernel<<<warpBlocks, warpThreads, 0, device.stream()>>>(
+    image.data(),
+    image_grad.data(),
+    flow.data(),
+    flow_grad.data(),
+    gradient.data(),
+    batch_size,
+    channels,
+    cblocks,
+    width,
+    wblocks,
+    height,
+    width_height);
+}
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/Codes/flownet2/src/ops/flow_warp/flow_warp_op.cc b/Codes/flownet2/src/ops/flow_warp/flow_warp_op.cc
new file mode 100644
index 0000000..aef9c74
--- /dev/null
+++ b/Codes/flownet2/src/ops/flow_warp/flow_warp_op.cc
@@ -0,0 +1,23 @@
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+REGISTER_OP("FlowWarp")
+.Input("image: float32")
+.Input("flow: float32")
+.Output("output: float32")
+.SetShapeFn(::tensorflow::shape_inference::UnchangedShape);
+
+REGISTER_OP("FlowWarpGrad")
+.Input("image: float32")
+.Input("flow: float32")
+.Input("gradient: float32")
+.Output("image_grad: float32")
+.Output("flow_grad: float32")
+.SetShapeFn([](shape_inference::InferenceContext *c) {
+    c->set_output(0, c->input(0));
+    c->set_output(1, c->input(1));
+    return Status::OK();
+  });
+} // namespace tensorflow
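Note: with \( \alpha = x_2 - \lfloor x_2 \rfloor \) and \( \beta = y_2 - \lfloor y_2 \rfloor \), the flow gradient accumulated by FlowWarpGradKernel above matches the analytic derivative of the bilinear sample \( W_c \); a sketch of the identities it implements, writing TL/TR/BL/BR for the four source taps:

\[
\frac{\partial W_c}{\partial u} = (1 - \beta)\,\big(I_c^{TR} - I_c^{TL}\big) + \beta\,\big(I_c^{BR} - I_c^{BL}\big),
\qquad
\frac{\partial W_c}{\partial v} = (1 - \alpha)\,\big(I_c^{BL} - I_c^{TL}\big) + \alpha\,\big(I_c^{BR} - I_c^{TR}\big).
\]

flow_grad at (x, y) is the sum over channels of gradient_c times these partials; the kernel's gamma equals \( 1 - \beta \) (respectively \( 1 - \alpha \)) except where the bottom/right neighbor is clamped at the image border. image_grad, by contrast, scatters gradient_c with the four bilinear weights through atomicAdd, since several warped pixels may sample the same source pixel.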
