summaryrefslogtreecommitdiff
path: root/Codes/flownet2/src/ops/downsample/downsample_op.cc
blob: 6980dc786ce4076bb8b491121fe132eed46e78cd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include <vector>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

// Shape function for the Downsample op.
//
// Requires input 0 to have rank 4 and sets output 0 to the shape
// [batch, size[0], size[1], depth], where `batch` and `depth` are copied from
// dimensions 0 and 3 of the input and `size` is the op's "size" attribute.
// Only the first two entries of "size" are used; any extra entries are
// ignored.
// NOTE(review): this treats the input as NHWC (dim 0 = batch, dim 3 =
// channels); confirm against the Downsample kernel implementation.
//
// Returns a non-OK Status if the input rank is not 4, if the "size"
// attribute cannot be read, or if it has fewer than two entries or a
// non-positive height/width.
Status SetOutputToSizedImage(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
  const DimensionHandle batch = c->Dim(input, 0);
  const DimensionHandle depth = c->Dim(input, 3);

  // Propagate lookup failures. Ignoring this Status would leave `size` empty,
  // and the indexing below would then read out of bounds.
  std::vector<int32> size;
  TF_RETURN_IF_ERROR(c->GetAttr("size", &size));

  // The "size: list(int) >= 2" registration already enforces the minimum
  // length for this op. The guard keeps size[0]/size[1] in range regardless.
  if (size.size() < 2) {
    return errors::InvalidArgument(
        "Downsample: attribute 'size' must have at least 2 entries, got ",
        size.size());
  }

  const int32 out_height = size[0];
  const int32 out_width = size[1];
  // A zero or negative output extent is not a meaningful downsample target.
  // Reject it here rather than emitting an invalid static shape.
  if (out_height <= 0 || out_width <= 0) {
    return errors::InvalidArgument(
        "Downsample: attribute 'size' must contain positive values, got [",
        out_height, ", ", out_width, "]");
  }

  c->set_output(0, c->MakeShape({batch, c->MakeDim(out_height),
                                 c->MakeDim(out_width), depth}));
  return Status::OK();
}

// Registers the "Downsample" op.
// Input:  `input`, a float32 tensor (rank 4 is enforced by the shape fn).
// Attr:   `size`, a list of ints with at least 2 entries; the shape function
//         uses size[0] as the output height and size[1] as the output width.
// Output: `output`, a float32 tensor whose shape is computed by
//         SetOutputToSizedImage.
REGISTER_OP("Downsample")
    .Input("input: float32")
    .Attr("size: list(int) >= 2")
    .Output("output: float32")
    .SetShapeFn(SetOutputToSizedImage);

}  // namespace tensorflow