blob: fe5fd4d5f929194d2db0c1533e4176b3bf2f5bcd (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
import tensorflow as tf
# Handle to the compiled custom-op library that provides the `flow_warp`
# and `flow_warp_grad` ops used below. get_path_to_datafile resolves the
# relative path against the directory of this source file (per TF's
# resource_loader implementation), so loading does not depend on the
# process's current working directory. The .so must be built beforehand
# into ops/build/.
_flow_warp_ops = tf.load_op_library(
    tf.resource_loader.get_path_to_datafile("./ops/build/flow_warp.so")
)
def flow_warp(image, flow):
    """Warp ``image`` by ``flow`` using the compiled ``flow_warp`` op.

    This is a thin wrapper around the op exposed by the loaded custom-op
    library. The gradient for the underlying ``FlowWarp`` op is registered
    in this module.

    Args:
        image: Tensor to be warped.
        flow: Tensor describing the per-pixel displacement.

    Returns:
        The tensor produced by the custom op.
    """
    # NOTE(review): shape/dtype requirements are enforced by the compiled
    # kernel (not visible here). Presumably image is (batch, height, width,
    # channels) and flow is (batch, height, width, 2); confirm against the
    # op's REGISTER_OP definition.
    warped = _flow_warp_ops.flow_warp(image, flow)
    return warped
@tf.RegisterGradient("FlowWarp")
def _flow_warp_grad(flow_warp_op, gradients):
    """Backward pass for the ``FlowWarp`` op.

    TensorFlow invokes this with the forward op and the upstream gradient
    of that op's output. It must return one gradient per op input. This
    function delegates to the compiled ``flow_warp_grad`` op, feeding it the
    forward inputs together with the upstream gradient.

    Args:
        flow_warp_op: The forward ``FlowWarp`` operation.
        gradients: Gradient flowing back into the op's output.

    Returns:
        Whatever ``flow_warp_grad`` produces. It is presumably a pair
        (grad w.r.t. image, grad w.r.t. flow); TODO confirm against the op
        registration.
    """
    # Index explicitly rather than unpacking, so no assumption is made
    # about how many inputs the op has beyond the first two.
    forward_image = flow_warp_op.inputs[0]
    forward_flow = flow_warp_op.inputs[1]
    return _flow_warp_ops.flow_warp_grad(forward_image, forward_flow, gradients)
|