blob: 6980dc786ce4076bb8b491121fe132eed46e78cd (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
|
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

// Shape function for the "Downsample" op.
//
// Requires input 0 to have rank 4. The output has rank 4 with dimensions
// {input.dim(0), size[0], size[1], input.dim(3)}; the local names below treat
// these axes as [batch, height, width, depth]. Only the first two entries of
// the "size" attr are used; any further entries are ignored.
//
// Returns an error status if input 0 is not rank 4, if the "size" attr cannot
// be read, or if "size" has fewer than two entries or a negative entry.
Status SetOutputToSizedImage(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
  const DimensionHandle batch = c->Dim(input, 0);
  const DimensionHandle depth = c->Dim(input, 3);

  // Propagate attr lookup failures instead of indexing an empty vector.
  std::vector<int32> size;
  TF_RETURN_IF_ERROR(c->GetAttr("size", &size));

  // The registration below already requires >= 2 entries; re-check locally so
  // the indexing that follows can never go out of bounds.
  if (size.size() < 2) {
    return errors::InvalidArgument(
        "Downsample: attr 'size' must have at least 2 entries, got ",
        size.size());
  }
  // A target spatial extent must be a concrete, non-negative size.
  if (size[0] < 0 || size[1] < 0) {
    return errors::InvalidArgument(
        "Downsample: attr 'size' entries must be non-negative, got [",
        size[0], ", ", size[1], "]");
  }

  const DimensionHandle height = c->MakeDim(size[0]);
  const DimensionHandle width = c->MakeDim(size[1]);
  c->set_output(0, c->MakeShape({batch, height, width, depth}));
  return Status::OK();
}

// "Downsample": takes a float32 tensor and produces a float32 tensor whose
// dims 1 and 2 are taken from the first two entries of the "size" attr
// (see SetOutputToSizedImage for the shape contract).
REGISTER_OP("Downsample")
    .Input("input: float32")
    .Attr("size: list(int) >= 2")
    .Output("output: float32")
    .SetShapeFn(SetOutputToSizedImage);

}  // namespace tensorflow
|