summaryrefslogtreecommitdiff
path: root/Codes/flownet2/src/flow_warp.py
blob: fe5fd4d5f929194d2db0c1533e4176b3bf2f5bcd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import tensorflow as tf

# Load the compiled custom-op shared library once, at import time.
# NOTE(review): tf.resource_loader.get_path_to_datafile presumably resolves the
# relative path against the directory of this module (not the CWD), so the
# .so under ops/build/ must be built before importing this file -- if it is
# missing, load_op_library will fail during import. Confirm against TF docs.
_flow_warp_ops = tf.load_op_library(
    tf.resource_loader.get_path_to_datafile(
        "./ops/build/flow_warp.so"
    )
)


def flow_warp(image, flow):
    """Warp ``image`` by ``flow`` via the compiled ``flow_warp`` custom op.

    Args:
        image: Tensor to be warped. NOTE(review): presumably a 4-D
            (batch, height, width, channels) float tensor -- confirm against
            the op's C++ registration.
        flow: Displacement field applied to ``image``. NOTE(review):
            presumably (batch, height, width, 2) -- confirm.

    Returns:
        The output tensor of the custom op, passed through unchanged.
    """
    warped = _flow_warp_ops.flow_warp(image, flow)
    return warped


@tf.RegisterGradient("FlowWarp")
def _flow_warp_grad(flow_warp_op, gradients):
    """Gradient function registered with TensorFlow for ops of type "FlowWarp".

    Args:
        flow_warp_op: The forward FlowWarp operation. Its first two inputs
            (image and flow) are forwarded to the gradient kernel.
        gradients: Upstream gradient with respect to the forward op's output.

    Returns:
        The result of the custom ``flow_warp_grad`` op, unchanged.
        NOTE(review): TF expects one gradient per forward input, so this op
        presumably yields (grad_image, grad_flow) -- confirm against the C++
        op registration.
    """
    image = flow_warp_op.inputs[0]
    flow = flow_warp_op.inputs[1]
    return _flow_warp_ops.flow_warp_grad(image, flow, gradients)