From fede6ca1dd0077ff509d84bd24028cc7a93bb119 Mon Sep 17 00:00:00 2001 From: StevenLiuWen Date: Tue, 13 Mar 2018 03:28:06 -0400 Subject: first commit --- .../src/ops/downsample/downsample_kernel.cc | 47 ++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 Codes/flownet2/src/ops/downsample/downsample_kernel.cc (limited to 'Codes/flownet2/src/ops/downsample/downsample_kernel.cc') diff --git a/Codes/flownet2/src/ops/downsample/downsample_kernel.cc b/Codes/flownet2/src/ops/downsample/downsample_kernel.cc new file mode 100644 index 0000000..eefe247 --- /dev/null +++ b/Codes/flownet2/src/ops/downsample/downsample_kernel.cc @@ -0,0 +1,47 @@ +#define EIGEN_USE_THREADS + +#include "downsample_kernel.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +template +class DownsampleKernel : public OpKernel { + public: + explicit DownsampleKernel(OpKernelConstruction* ctx) : OpKernel(ctx) { + // Get the size [height, width] tensor and verify its dimensions + OP_REQUIRES_OK(ctx, ctx->GetAttr("size", &size_)); + OP_REQUIRES(ctx, size_.size() == 2, errors::InvalidArgument("size must be 2 dimensions")); + } + + void Compute(OpKernelContext* ctx) override { + // Get the input images and transforms and verify their dimensions + const Tensor& input_t = ctx->input(0); + OP_REQUIRES(ctx, input_t.dims() == 4, + errors::InvalidArgument("Input images must have rank 4")); + + // Allocate the memory for the output + Tensor* output_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 0, TensorShape({input_t.dim_size(0), size_[0], size_[1], input_t.dim_size(3)}), &output_t)); + + // Perform flow augmentation + auto input = input_t.tensor(); + auto output = output_t->tensor(); + + Downsample(ctx->eigen_gpu_device(), input, output); + } + + private: + std::vector size_; +}; + +REGISTER_KERNEL_BUILDER(Name("Downsample") + .Device(DEVICE_GPU), + DownsampleKernel) +} // end namespace tensorflow -- cgit v1.2.3-70-g09d2