From fede6ca1dd0077ff509d84bd24028cc7a93bb119 Mon Sep 17 00:00:00 2001 From: StevenLiuWen Date: Tue, 13 Mar 2018 03:28:06 -0400 Subject: first commit --- Codes/flownet2/src/ops/downsample/downsample_op.cc | 30 ++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 Codes/flownet2/src/ops/downsample/downsample_op.cc (limited to 'Codes/flownet2/src/ops/downsample/downsample_op.cc') diff --git a/Codes/flownet2/src/ops/downsample/downsample_op.cc b/Codes/flownet2/src/ops/downsample/downsample_op.cc new file mode 100644 index 0000000..6980dc7 --- /dev/null +++ b/Codes/flownet2/src/ops/downsample/downsample_op.cc @@ -0,0 +1,30 @@ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; +using shape_inference::DimensionHandle; + +Status SetOutputToSizedImage(InferenceContext* c) { + ShapeHandle input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); + DimensionHandle batch = c->Dim(input, 0); + DimensionHandle depth = c->Dim(input, 3); + std::vector size_; + c->GetAttr("size", &size_); + DimensionHandle height = c->MakeDim(size_[0]); + DimensionHandle width = c->MakeDim(size_[1]); + c->set_output(0, c->MakeShape({batch, height, width, depth})); + return Status::OK(); +} + +REGISTER_OP("Downsample") + .Input("input: float32") + .Attr("size: list(int) >= 2") + .Output("output: float32") + .SetShapeFn(SetOutputToSizedImage); + +} // namespace tensorflow -- cgit v1.2.3-70-g09d2