diff options
Diffstat (limited to 'Codes/flownet2/src/ops/flow_warp/flow_warp.cc')
| -rw-r--r-- | Codes/flownet2/src/ops/flow_warp/flow_warp.cc | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/Codes/flownet2/src/ops/flow_warp/flow_warp.cc b/Codes/flownet2/src/ops/flow_warp/flow_warp.cc new file mode 100644 index 0000000..b5d9602 --- /dev/null +++ b/Codes/flownet2/src/ops/flow_warp/flow_warp.cc @@ -0,0 +1,48 @@ +#define EIGEN_USE_THREADS + +#include "flow_warp.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; + +template<typename Device> +class FlowWarpKernel : public OpKernel { + public: + explicit FlowWarpKernel(OpKernelConstruction *ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext *ctx) override { + // Get the input image and flow and verify dimensions + const Tensor& input_t = ctx->input(0); + const Tensor& flow_t = ctx->input(1); + + OP_REQUIRES(ctx, input_t.dims() == 4, + errors::InvalidArgument("Input image must have rank 4")); + OP_REQUIRES(ctx, flow_t.dims() == 4, + errors::InvalidArgument("Input flow must have rank 4")); + OP_REQUIRES(ctx, + input_t.dim_size(0) == flow_t.dim_size(0) && input_t.dim_size( + 1) == flow_t.dim_size(1) && input_t.dim_size(2) == flow_t.dim_size(2), + errors::InvalidArgument( + "Input image and flow must have same N x H x W dimensions")); + + // Allocate the memory for the output + Tensor *output_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input_t.shape(), &output_t)); + + // Perform flow augmentation + auto input = input_t.tensor<float, 4>(); + auto flow = flow_t.tensor<float, 4>(); + auto output = output_t->tensor<float, 4>(); + + FlowWarp(ctx->eigen_gpu_device(), input, flow, output); + } +}; + +REGISTER_KERNEL_BUILDER(Name("FlowWarp") + .Device(DEVICE_GPU), + FlowWarpKernel<GPUDevice>) +} // end namespace tensorflow |
