diff options
Diffstat (limited to 'Codes/flownet2/src/ops/downsample')
4 files changed, 203 insertions, 0 deletions
diff --git a/Codes/flownet2/src/ops/downsample/downsample_kernel.cc b/Codes/flownet2/src/ops/downsample/downsample_kernel.cc
new file mode 100644
index 0000000..eefe247
--- /dev/null
+++ b/Codes/flownet2/src/ops/downsample/downsample_kernel.cc
@@ -0,0 +1,71 @@
+#define EIGEN_USE_THREADS
+
+#include "downsample_kernel.h"
+
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Op kernel that resamples a rank-4 float input to the fixed [height, width]
+// given by the "size" attribute. The first and last input dimensions are
+// carried over to the output unchanged; Downsample() does the resampling.
+template <typename Device>
+class DownsampleKernel : public OpKernel {
+ public:
+  // Reads the "size" attribute and rejects anything other than exactly two
+  // positive values, so a bad attribute fails when the kernel is constructed.
+  explicit DownsampleKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("size", &size_));
+    OP_REQUIRES(ctx, size_.size() == 2,
+                errors::InvalidArgument("size must be 2 dimensions"));
+    OP_REQUIRES(ctx, size_[0] > 0 && size_[1] > 0,
+                errors::InvalidArgument("size must be positive, got [",
+                                        size_[0], ", ", size_[1], "]"));
+  }
+
+  // Validates input 0, allocates output 0 with shape
+  // [dim0, size[0], size[1], dim3] and runs the resampling on the GPU.
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& input_t = ctx->input(0);
+    OP_REQUIRES(ctx, input_t.dims() == 4,
+                errors::InvalidArgument("Input images must have rank 4"));
+
+    Tensor* output_t = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(
+                            0,
+                            TensorShape({input_t.dim_size(0), size_[0],
+                                         size_[1], input_t.dim_size(3)}),
+                            &output_t));
+
+    // Downsample() computes element offsets in 32-bit ints, so larger
+    // tensors would overflow them.
+    OP_REQUIRES(ctx,
+                input_t.NumElements() <= kint32max &&
+                    output_t->NumElements() <= kint32max,
+                errors::InvalidArgument(
+                    "Downsample tensors must have at most ", kint32max,
+                    " elements"));
+
+    auto input = input_t.tensor<float, 4>();
+    auto output = output_t->tensor<float, 4>();
+
+    // Report a failed launch instead of discarding the returned status.
+    OP_REQUIRES(ctx, Downsample(ctx->eigen_device<Device>(), input, output),
+                errors::Internal("Downsample GPU kernel launch failed"));
+  }
+
+ private:
+  std::vector<int32> size_;  // Requested output [height, width].
+};
+
+REGISTER_KERNEL_BUILDER(Name("Downsample")
+                        .Device(DEVICE_GPU),
+                        DownsampleKernel<GPUDevice>);
+}  // end namespace tensorflow
diff --git a/Codes/flownet2/src/ops/downsample/downsample_kernel.h b/Codes/flownet2/src/ops/downsample/downsample_kernel.h
new file mode 100644
index 0000000..bcc4e3f
--- /dev/null
+++ b/Codes/flownet2/src/ops/downsample/downsample_kernel.h
@@ -0,0 +1,21 @@
+#ifndef FLOWNET_DOWNSAMPLE_H_
+#define FLOWNET_DOWNSAMPLE_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Resamples the rank-4 float tensor `input` into `output` on the GPU stream
+// of `device`; the target height and width are `output`'s dimensions 1 and 2.
+// Returns device.ok() as observed after the kernel has been enqueued.
+bool Downsample(const GPUDevice& device,
+                typename TTypes<float, 4>::ConstTensor input,
+                typename TTypes<float, 4>::Tensor output);
+
+}  // end namespace tensorflow
+
+#endif  // FLOWNET_DOWNSAMPLE_H_
diff --git a/Codes/flownet2/src/ops/downsample/downsample_kernel_gpu.cu.cc b/Codes/flownet2/src/ops/downsample/downsample_kernel_gpu.cu.cc
new file mode 100644
index 0000000..b7629a0
--- /dev/null
+++ b/Codes/flownet2/src/ops/downsample/downsample_kernel_gpu.cu.cc
@@ -0,0 +1,154 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+#include <iostream>
+
+#include "downsample_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+#define CUDART_NAN_F __int_as_float(0x7fffffff)
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace {
+
+// Source pixels advanced per output pixel when the first and last pixels of
+// both axes are aligned. A single-pixel output has no second pixel to align,
+// so it falls back to in_size / out_size (== in_size), the convention
+// TensorFlow's corner-aligned image resizing uses, instead of dividing by 0.
+inline float CornerAlignedScale(const int in_size, const int out_size) {
+  return out_size > 1 ? static_cast<float>(in_size - 1) /
+                            static_cast<float>(out_size - 1)
+                      : static_cast<float>(in_size);
+}
+
+}  // namespace
+
+// Writes one output element per loop iteration: a weighted average of the
+// input samples within (wradius, hradius) of the matching source position.
+// Weights are separable: max(0, 1 - |offset| / scale) along each axis.
+// NaN samples are kept out of the average and their weight is tallied
+// separately; the output is NaN once that tally exceeds half of the weight
+// of the non-NaN samples. An output with no non-NaN weight is NaN too (0/0).
+__global__ void DownsampleKernel(
+    const int32 nthreads,
+    const float* input_ptr,
+    float* output_ptr,
+    const int in_width,
+    const int in_height,
+    const int out_width,
+    const int out_height,
+    const int channels,
+    const float width_scale,
+    const float height_scale,
+    const int wradius,
+    const int hradius) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // Split the flat output index into (n, desty, destx, c).
+    const int c = index % channels;
+    const int destx = (index / channels) % out_width;
+    const int desty = (index / channels / out_width) % out_height;
+    const int n = (index / channels / out_width) / out_height;
+
+    // Corner-aligned source position. A single-pixel output axis is pinned
+    // to source coordinate 0 rather than evaluating 0 / 0.
+    const float srcx =
+        out_width > 1 ? (static_cast<float>(destx) /
+                         static_cast<float>(out_width - 1)) *
+                            static_cast<float>(in_width - 1)
+                      : 0.0f;
+    const float srcy =
+        out_height > 1 ? (static_cast<float>(desty) /
+                          static_cast<float>(out_height - 1)) *
+                             static_cast<float>(in_height - 1)
+                       : 0.0f;
+
+    const int isrcx = static_cast<int>(roundf(srcx));
+    const int isrcy = static_cast<int>(roundf(srcy));
+
+    float accum_value = 0.0f;
+    float accum_weight = 0.0f;
+    float accum_nan = 0.0f;
+
+    for (int dy = -hradius; dy <= hradius; dy++) {
+      const int yoff = isrcy + dy;
+      if (yoff < 0 || yoff >= in_height) continue;
+      // The row weight is the same for every column, so compute it once.
+      const float yweight = fmaxf(
+          0.0f, 1.0f - (fabsf(static_cast<float>(yoff) - srcy) / height_scale));
+
+      for (int dx = -wradius; dx <= wradius; dx++) {
+        const int xoff = isrcx + dx;
+        if (xoff < 0 || xoff >= in_width) continue;
+
+        const int idx = ((n * in_height + yoff) * in_width + xoff) * channels + c;
+        const float sample = input_ptr[idx];
+        const float weight =
+            fmaxf(0.0f, 1.0f - (fabsf(static_cast<float>(xoff) - srcx) / width_scale)) *
+            yweight;
+
+        if (sample != sample) {  // NaN is the only value unequal to itself.
+          accum_nan += weight;
+        } else {
+          accum_value += sample * weight;
+          accum_weight += weight;
+        }
+      }
+    }
+
+    if (accum_nan / accum_weight > 0.5f) {
+      output_ptr[index] = CUDART_NAN_F;
+    } else {
+      output_ptr[index] = accum_value / accum_weight;
+    }
+  }
+}
+
+// Launches DownsampleKernel with one thread-loop iteration per output
+// element. Output height and width come from `output`; batch and channel
+// counts are read from `output` as well. Returns device.ok().
+bool Downsample(const GPUDevice& device,
+                typename TTypes<float, 4>::ConstTensor input,
+                typename TTypes<float, 4>::Tensor output) {
+  const int batch_size = output.dimension(0);
+  const int out_height = output.dimension(1);
+  const int out_width = output.dimension(2);
+  const int out_channels = output.dimension(3);
+  const int total_count = batch_size * out_height * out_width * out_channels;
+
+  // An empty output needs no work, and a launch over zero elements is not a
+  // valid configuration.
+  if (total_count <= 0) {
+    return device.ok();
+  }
+
+  const int in_height = input.dimension(1);
+  const int in_width = input.dimension(2);
+
+  const float width_scale = CornerAlignedScale(in_width, out_width);
+  const float height_scale = CornerAlignedScale(in_height, out_height);
+
+  // Visit every source pixel whose weight can be non-zero.
+  const int wradius = static_cast<int>(ceilf(width_scale));
+  const int hradius = static_cast<int>(ceilf(height_scale));
+
+  CudaLaunchConfig config = GetCudaLaunchConfig(total_count, device);
+  DownsampleKernel<<<config.block_count, config.thread_per_block, 0,
+                     device.stream()>>>(total_count, input.data(), output.data(),
+                                        in_width, in_height, out_width,
+                                        out_height, out_channels,
+                                        width_scale, height_scale, wradius, hradius);
+  return device.ok();
+}
+
+}  // end namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/Codes/flownet2/src/ops/downsample/downsample_op.cc b/Codes/flownet2/src/ops/downsample/downsample_op.cc
new file mode 100644
index 0000000..6980dc7
--- /dev/null
+++ b/Codes/flownet2/src/ops/downsample/downsample_op.cc
@@ -0,0 +1,48 @@
+#include <vector>
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::DimensionHandle;
+
+// Shape function for the "Downsample" op: requires a rank-4 input and sets
+// output 0 to [dim0, size[0], size[1], dim3], where dim0 and dim3 are the
+// input's first and last dimensions and `size` is the op attribute.
+Status SetOutputToSizedImage(InferenceContext* c) {
+  ShapeHandle input;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+
+  // Propagate a failed attribute lookup rather than indexing an empty vector.
+  std::vector<int32> size;
+  TF_RETURN_IF_ERROR(c->GetAttr("size", &size));
+  // The attr constraint only enforces a minimum length; the kernel accepts
+  // exactly two positive values, so reject everything else here as well.
+  if (size.size() != 2) {
+    return errors::InvalidArgument("size must be 2 dimensions");
+  }
+  if (size[0] <= 0 || size[1] <= 0) {
+    return errors::InvalidArgument("size must be positive, got [", size[0],
+                                   ", ", size[1], "]");
+  }
+
+  const DimensionHandle batch = c->Dim(input, 0);
+  const DimensionHandle depth = c->Dim(input, 3);
+  const DimensionHandle height = c->MakeDim(size[0]);
+  const DimensionHandle width = c->MakeDim(size[1]);
+  c->set_output(0, c->MakeShape({batch, height, width, depth}));
+  return Status::OK();
+}
+
+REGISTER_OP("Downsample")
+    .Input("input: float32")
+    .Attr("size: list(int) >= 2")
+    .Output("output: float32")
+    .SetShapeFn(SetOutputToSizedImage);
+
+}  // namespace tensorflow
