summaryrefslogtreecommitdiff
path: root/Codes/flownet2/src/ops/downsample/downsample_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'Codes/flownet2/src/ops/downsample/downsample_op.cc')
-rw-r--r--Codes/flownet2/src/ops/downsample/downsample_op.cc30
1 files changed, 30 insertions, 0 deletions
diff --git a/Codes/flownet2/src/ops/downsample/downsample_op.cc b/Codes/flownet2/src/ops/downsample/downsample_op.cc
new file mode 100644
index 0000000..6980dc7
--- /dev/null
+++ b/Codes/flownet2/src/ops/downsample/downsample_op.cc
@@ -0,0 +1,30 @@
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::DimensionHandle;
+
+Status SetOutputToSizedImage(InferenceContext* c) {
+ ShapeHandle input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+ DimensionHandle batch = c->Dim(input, 0);
+ DimensionHandle depth = c->Dim(input, 3);
+ std::vector<int32> size_;
+ c->GetAttr("size", &size_);
+ DimensionHandle height = c->MakeDim(size_[0]);
+ DimensionHandle width = c->MakeDim(size_[1]);
+ c->set_output(0, c->MakeShape({batch, height, width, depth}));
+ return Status::OK();
+}
+
+REGISTER_OP("Downsample")
+ .Input("input: float32")
+ .Attr("size: list(int) >= 2")
+ .Output("output: float32")
+ .SetShapeFn(SetOutputToSizedImage);
+
+} // namespace tensorflow