Diffstat (limited to 'Codes/flownet2/src/ops/downsample')
-rw-r--r--  Codes/flownet2/src/ops/downsample/downsample_kernel.cc          47
-rw-r--r--  Codes/flownet2/src/ops/downsample/downsample_kernel.h           18
-rw-r--r--  Codes/flownet2/src/ops/downsample/downsample_kernel_gpu.cu.cc  108
-rw-r--r--  Codes/flownet2/src/ops/downsample/downsample_op.cc              30
4 files changed, 203 insertions, 0 deletions
diff --git a/Codes/flownet2/src/ops/downsample/downsample_kernel.cc b/Codes/flownet2/src/ops/downsample/downsample_kernel.cc
new file mode 100644
index 0000000..eefe247
--- /dev/null
+++ b/Codes/flownet2/src/ops/downsample/downsample_kernel.cc
@@ -0,0 +1,47 @@
+#define EIGEN_USE_THREADS
+
+#include "downsample_kernel.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device>
+class DownsampleKernel : public OpKernel {
+ public:
+ explicit DownsampleKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    // Get the size [height, width] attribute and verify that it has two elements
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("size", &size_));
+    OP_REQUIRES(ctx, size_.size() == 2, errors::InvalidArgument("size must contain 2 elements"));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+    // Get the input images and verify that they have rank 4
+ const Tensor& input_t = ctx->input(0);
+ OP_REQUIRES(ctx, input_t.dims() == 4,
+ errors::InvalidArgument("Input images must have rank 4"));
+
+ // Allocate the memory for the output
+ Tensor* output_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(
+ 0, TensorShape({input_t.dim_size(0), size_[0], size_[1], input_t.dim_size(3)}), &output_t));
+
+    // Run the CUDA downsampling kernel
+ auto input = input_t.tensor<float, 4>();
+ auto output = output_t->tensor<float, 4>();
+
+ Downsample(ctx->eigen_gpu_device(), input, output);
+ }
+
+ private:
+ std::vector<int32> size_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("Downsample")
+ .Device(DEVICE_GPU),
+    DownsampleKernel<GPUDevice>);
+} // end namespace tensorflow
diff --git a/Codes/flownet2/src/ops/downsample/downsample_kernel.h b/Codes/flownet2/src/ops/downsample/downsample_kernel.h
new file mode 100644
index 0000000..bcc4e3f
--- /dev/null
+++ b/Codes/flownet2/src/ops/downsample/downsample_kernel.h
@@ -0,0 +1,18 @@
+#ifndef FLOWNET_DOWNSAMPLE_H_
+#define FLOWNET_DOWNSAMPLE_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+bool Downsample(const GPUDevice& device,
+ typename TTypes<float, 4>::ConstTensor input,
+ typename TTypes<float, 4>::Tensor output);
+
+} // end namespace tensorflow
+
+#endif // FLOWNET_DOWNSAMPLE_H_
diff --git a/Codes/flownet2/src/ops/downsample/downsample_kernel_gpu.cu.cc b/Codes/flownet2/src/ops/downsample/downsample_kernel_gpu.cu.cc
new file mode 100644
index 0000000..b7629a0
--- /dev/null
+++ b/Codes/flownet2/src/ops/downsample/downsample_kernel_gpu.cu.cc
@@ -0,0 +1,108 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+#include <iostream>
+
+#include "downsample_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+#define CUDART_NAN_F __int_as_float(0x7fffffff)
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+__global__ void DownsampleKernel(
+ const int32 nthreads,
+ const float* input_ptr,
+ float* output_ptr,
+ const int in_width,
+ const int in_height,
+ const int out_width,
+ const int out_height,
+ const int channels,
+ const float width_scale,
+ const float height_scale,
+ const int wradius,
+ const int hradius) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ const int c = index % channels;
+ const int destx = (index / channels) % out_width;
+ const int desty = (index / channels / out_width) % out_height;
+ const int n = (index / channels / out_width) / out_height;
+
+ const float srcx = ((float)destx / (float)(out_width - 1)) * (float)(in_width - 1);
+ const float srcy = ((float)desty / (float)(out_height - 1)) * (float)(in_height - 1);
+
+ const int isrcx = round(srcx);
+ const int isrcy = round(srcy);
+
+ float accum_value = 0;
+ float accum_weight = 0;
+ float accum_nan = 0;
+
+ for (int dy = -hradius; dy <= hradius; dy++) {
+ int yoff = isrcy + dy;
+      // Scan the horizontal extent of the tent filter around the source column
+ for (int dx = -wradius; dx <= wradius; dx++) {
+ int xoff = isrcx + dx;
+
+ if (xoff >= 0 && yoff >= 0 && xoff < in_width && yoff < in_height) {
+ int idx = ((n * in_height + yoff) * in_width + xoff) * channels + c;
+ float sample = input_ptr[idx];
+ float weight = fmaxf(0.0f, 1.0f - (fabsf((float)xoff - srcx) / width_scale))
+ * fmaxf(0.0f, 1.0f - (fabsf((float)yoff - srcy) / height_scale));
+ if (sample != sample) { // isnan
+ accum_nan += weight;
+ sample = 0;
+ weight = 0;
+ }
+ accum_value += sample * weight;
+ accum_weight += weight;
+ }
+ }
+ }
+
+ if (accum_nan / accum_weight > 0.5) {
+ output_ptr[index] = CUDART_NAN_F;
+ } else {
+ output_ptr[index] = accum_value / accum_weight;
+ }
+ }
+}
+
+bool Downsample(const GPUDevice& device,
+ typename TTypes<float, 4>::ConstTensor input,
+ typename TTypes<float, 4>::Tensor output) {
+ const int batch_size = output.dimension(0);
+ const int out_height = output.dimension(1);
+ const int out_width = output.dimension(2);
+ const int out_channels = output.dimension(3);
+ const int total_count = batch_size * out_height * out_width * out_channels;
+
+ const int in_height = input.dimension(1);
+ const int in_width = input.dimension(2);
+
+ const float width_scale = (float)(in_width - 1) / (float)(out_width - 1);
+ const float height_scale = (float)(in_height - 1) / (float)(out_height - 1);
+
+ const int wradius = ceil(width_scale);
+ const int hradius = ceil(height_scale);
+
+ CudaLaunchConfig config = GetCudaLaunchConfig(total_count, device);
+ DownsampleKernel<<<config.block_count, config.thread_per_block, 0,
+ device.stream()>>>(total_count, input.data(), output.data(),
+ in_width, in_height, out_width, out_height, out_channels,
+ width_scale, height_scale, wradius, hradius);
+ return device.ok();
+}
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
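
The resampling performed by `DownsampleKernel` is a tent (triangle) filter centred on the nearest source pixel: each output value is a weighted average over a window of `(2*wradius+1) x (2*hradius+1)` input pixels, with NaN samples excluded from the average and the output forced to NaN when NaNs carry more than half of the filter weight. The NumPy sketch below is not part of the patch; it re-implements the same arithmetic on the CPU (the helper name `downsample_reference` is illustrative) and can be used to spot-check the CUDA output on small NHWC float32 arrays.

```python
import numpy as np

def downsample_reference(inp, out_h, out_w):
    """CPU re-implementation of the tent-filter resampling in DownsampleKernel.

    inp: float32 array of shape [N, H, W, C]. NaN samples are excluded from
    the weighted average; an output pixel becomes NaN if NaNs carry more than
    half of the filter weight, mirroring the CUDA kernel.
    """
    n, in_h, in_w, c = inp.shape
    width_scale = (in_w - 1) / float(out_w - 1)
    height_scale = (in_h - 1) / float(out_h - 1)
    wradius = int(np.ceil(width_scale))
    hradius = int(np.ceil(height_scale))
    out = np.empty((n, out_h, out_w, c), dtype=np.float32)

    for b in range(n):
        for desty in range(out_h):
            srcy = desty / float(out_h - 1) * (in_h - 1)
            isrcy = int(np.floor(srcy + 0.5))  # round half up, like CUDA round() for x >= 0
            for destx in range(out_w):
                srcx = destx / float(out_w - 1) * (in_w - 1)
                isrcx = int(np.floor(srcx + 0.5))
                for ch in range(c):
                    accum_value = accum_weight = accum_nan = 0.0
                    for yoff in range(isrcy - hradius, isrcy + hradius + 1):
                        if yoff < 0 or yoff >= in_h:
                            continue
                        wy = max(0.0, 1.0 - abs(yoff - srcy) / height_scale)
                        for xoff in range(isrcx - wradius, isrcx + wradius + 1):
                            if xoff < 0 or xoff >= in_w:
                                continue
                            weight = wy * max(0.0, 1.0 - abs(xoff - srcx) / width_scale)
                            sample = inp[b, yoff, xoff, ch]
                            if np.isnan(sample):
                                accum_nan += weight
                                sample, weight = 0.0, 0.0
                            accum_value += sample * weight
                            accum_weight += weight
                    if accum_nan / accum_weight > 0.5:
                        out[b, desty, destx, ch] = np.nan
                    else:
                        out[b, desty, destx, ch] = accum_value / accum_weight
    return out
```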
diff --git a/Codes/flownet2/src/ops/downsample/downsample_op.cc b/Codes/flownet2/src/ops/downsample/downsample_op.cc
new file mode 100644
index 0000000..6980dc7
--- /dev/null
+++ b/Codes/flownet2/src/ops/downsample/downsample_op.cc
@@ -0,0 +1,30 @@
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::DimensionHandle;
+
+Status SetOutputToSizedImage(InferenceContext* c) {
+ ShapeHandle input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+ DimensionHandle batch = c->Dim(input, 0);
+ DimensionHandle depth = c->Dim(input, 3);
+  std::vector<int32> size;
+  TF_RETURN_IF_ERROR(c->GetAttr("size", &size));
+  DimensionHandle height = c->MakeDim(size[0]);
+  DimensionHandle width = c->MakeDim(size[1]);
+ c->set_output(0, c->MakeShape({batch, height, width, depth}));
+ return Status::OK();
+}
+
+REGISTER_OP("Downsample")
+ .Input("input: float32")
+ .Attr("size: list(int) >= 2")
+ .Output("output: float32")
+ .SetShapeFn(SetOutputToSizedImage);
+
+} // namespace tensorflow
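
The registered op is GPU-only and takes the target resolution as a `size = [height, width]` attribute, with the output shape inferred as `[batch, height, width, channels]`. A minimal usage sketch in TF1-style Python, assuming the kernels above have been compiled into a shared object (the path `./downsample.so` and the example shapes are placeholders; the build step is not part of this diff):

```python
import numpy as np
import tensorflow as tf

# Placeholder path: the actual .so name and location depend on how the op is built.
_module = tf.load_op_library('./downsample.so')

# NHWC float32 input; the op resamples to the [height, width] given in `size`.
images = tf.placeholder(tf.float32, shape=[None, 384, 512, 3])
small = _module.downsample(images, size=[192, 256])

with tf.Session() as sess:
    batch = np.random.rand(4, 384, 512, 3).astype(np.float32)
    print(sess.run(small, feed_dict={images: batch}).shape)  # (4, 192, 256, 3)
```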