diff options
Diffstat (limited to 'Codes/flownet2/src/ops/flow_warp/flow_warp.h')
| -rw-r--r-- | Codes/flownet2/src/ops/flow_warp/flow_warp.h | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/Codes/flownet2/src/ops/flow_warp/flow_warp.h b/Codes/flownet2/src/ops/flow_warp/flow_warp.h new file mode 100644 index 0000000..2780316 --- /dev/null +++ b/Codes/flownet2/src/ops/flow_warp/flow_warp.h @@ -0,0 +1,28 @@ +#ifndef FLOWNET_FLOWWARP_H_ +#define FLOWNET_FLOWWARP_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +#define FW_THREADS 32 +#define FW_TILE_X FW_THREADS +#define FW_TILE_C FW_THREADS + +namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; + +void FlowWarp(const GPUDevice& device, + typename TTypes<float, 4>::ConstTensor input, + typename TTypes<float, 4>::ConstTensor flow, + typename TTypes<float, 4>::Tensor output); + +void FlowWarpGrad(const GPUDevice& device, + typename TTypes<float, 4>::ConstTensor image, + typename TTypes<float, 4>::ConstTensor flow, + typename TTypes<float, 4>::ConstTensor gradient, + typename TTypes<float, 4>::Tensor image_grad, + typename TTypes<float, 4>::Tensor flow_grad); +} // end namespace tensorflow + +#endif // FLOWNET_FLOWWARP_H_ |
