diff options
| author | StevenLiuWen <liuwen@shanghaitech.edu.cn> | 2018-03-13 03:28:06 -0400 |
|---|---|---|
| committer | StevenLiuWen <liuwen@shanghaitech.edu.cn> | 2018-03-13 03:28:06 -0400 |
| commit | fede6ca1dd0077ff509d84bd24028cc7a93bb119 (patch) | |
| tree | af7f6e759b5dec4fc2964daed09e903958b919ed /Codes/models.py | |
first commit
Diffstat (limited to 'Codes/models.py')
| -rw-r--r-- | Codes/models.py | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/Codes/models.py b/Codes/models.py new file mode 100644 index 0000000..8c20134 --- /dev/null +++ b/Codes/models.py @@ -0,0 +1,44 @@ +import tensorflow as tf + +import unet +import pix2pix + +from flownet2.src.flowlib import flow_to_image +from flownet2.src.flownet_sd.flownet_sd import FlowNetSD # Ok +from flownet2.src.training_schedules import LONG_SCHEDULE +from flownet2.src.net import Mode + + +slim = tf.contrib.slim + + +def generator(inputs, layers, features_root=64, filter_size=3, pool_size=2, output_channel=3): + return unet.unet(inputs, layers, features_root, filter_size, pool_size, output_channel) + + +def discriminator(inputs, num_filers=(128, 256, 512, 512)): + logits, end_points = pix2pix.pix2pix_discriminator(inputs, num_filers) + return logits, end_points['predictions'] + + +def flownet(input_a, input_b, height, width, reuse=None): + net = FlowNetSD(mode=Mode.TEST) + # train preds flow + input_a = (input_a + 1.0) / 2.0 # flownet receives image with color space in [0, 1] + input_b = (input_b + 1.0) / 2.0 # flownet receives image with color space in [0, 1] + # input size is 384 x 512 + input_a = tf.image.resize_images(input_a, [height, width]) + input_b = tf.image.resize_images(input_b, [height, width]) + flows = net.model( + inputs={'input_a': input_a, 'input_b': input_b}, + training_schedule=LONG_SCHEDULE, + trainable=False, reuse=reuse + ) + return flows['flow'] + + +def initialize_flownet(sess, checkpoint): + flownet_vars = slim.get_variables_to_restore(include=['FlowNetSD']) + flownet_saver = tf.train.Saver(flownet_vars) + print('FlownetSD restore from {}!'.format(checkpoint)) + flownet_saver.restore(sess, checkpoint) |
