summaryrefslogtreecommitdiff
path: root/Codes/inference-test.py
diff options
context:
space:
mode:
Diffstat (limited to 'Codes/inference-test.py')
-rw-r--r--Codes/inference-test.py155
1 files changed, 155 insertions, 0 deletions
diff --git a/Codes/inference-test.py b/Codes/inference-test.py
new file mode 100644
index 0000000..b16a56c
--- /dev/null
+++ b/Codes/inference-test.py
@@ -0,0 +1,155 @@
+import tensorflow as tf
+import os
+import time
+import numpy as np
+import pickle
+
+
+from models import generator
+from utils import DataLoader, load, save, psnr_error
+from constant import const
+import evaluate
+
+
+slim = tf.contrib.slim
+
+os.environ['CUDA_DEVICES_ORDER'] = "PCI_BUS_ID"
+os.environ['CUDA_VISIBLE_DEVICES'] = const.GPU
+
+dataset_name = const.DATASET
+test_folder = const.TEST_FOLDER
+
+num_his = const.NUM_HIS
+height, width = 256, 256
+
+snapshot_dir = const.SNAPSHOT_DIR
+psnr_dir = const.PSNR_DIR
+evaluate_name = const.EVALUATE
+
+print(const)
+
+
+# define dataset
+with tf.name_scope('dataset'):
+ test_video_clips_tensor = tf.placeholder(shape=[1, height, width, 3 * (num_his + 1)],
+ dtype=tf.float32)
+ test_inputs = test_video_clips_tensor[..., 0:num_his*3]
+ test_gt = test_video_clips_tensor[..., -3:]
+ print('test inputs = {}'.format(test_inputs))
+ print('test prediction gt = {}'.format(test_gt))
+
+# define testing generator function and
+# in testing, only generator networks, there is no discriminator networks and flownet.
+with tf.variable_scope('generator', reuse=None):
+ print('testing = {}'.format(tf.get_variable_scope().name))
+ test_outputs = generator(test_inputs, layers=4, output_channel=3)
+ test_psnr_error = psnr_error(gen_frames=test_outputs, gt_frames=test_gt)
+
+
+config = tf.ConfigProto()
+config.gpu_options.allow_growth = True
+with tf.Session(config=config) as sess:
+ # dataset
+ data_loader = DataLoader(test_folder, height, width)
+
+ # initialize weights
+ sess.run(tf.global_variables_initializer())
+ print('Init global successfully!')
+
+ # tf saver
+ saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=None)
+
+ restore_var = [v for v in tf.global_variables()]
+ loader = tf.train.Saver(var_list=restore_var)
+
+ def inference_func(ckpt, dataset_name, evaluate_name):
+ load(loader, sess, ckpt)
+
+ output_records = []
+ videos_info = data_loader.videos
+ num_videos = len(videos_info.keys())
+ total = 0
+ timestamp = time.time()
+
+ for video_name, video in videos_info.items():
+ length = video['length']
+ total += length
+ psnrs = np.empty(shape=(length,), dtype=np.float32)
+
+ for i in range(num_his, length):
+ video_clip = data_loader.get_video_clips(video_name, i - num_his, i + 1)
+ output, psnr = sess.run([test_outputs, test_psnr_error]
+ feed_dict={test_video_clips_tensor: video_clip[np.newaxis, ...]})
+ outputs[i] = output
+
+ tf.image.encode_png(
+ output,
+ compression=-1,
+ name=None
+ )
+
+ print('video = {} / {}, i = {} / {}, psnr = {:.6f}'.format(
+ video_name, num_videos, i, length, psnr))
+
+ outputs[0:num_his] = outputs[num_his]
+ output_records.append(outputs)
+
+ result_dict = {'dataset': dataset_name, 'output': output_records, 'flow': [], 'names': [], 'diff_mask': []}
+
+ used_time = time.time() - timestamp
+ print('total time = {}, fps = {}'.format(used_time, total / used_time))
+
+ # TODO specify what's the actual name of ckpt.
+ pickle_path = os.path.join(output_dir, os.path.split(ckpt)[-1])
+ with open(pickle_path, 'wb') as writer:
+ pickle.dump(result_dict, writer, pickle.HIGHEST_PROTOCOL)
+
+ # results = evaluate.evaluate(evaluate_name, pickle_path)
+ # print(results)
+
+
+ if os.path.isdir(snapshot_dir):
+ def check_ckpt_valid(ckpt_name):
+ is_valid = False
+ ckpt = ''
+ if ckpt_name.startswith('model.ckpt-'):
+ ckpt_name_splits = ckpt_name.split('.')
+ ckpt = str(ckpt_name_splits[0]) + '.' + str(ckpt_name_splits[1])
+ ckpt_path = os.path.join(snapshot_dir, ckpt)
+ if os.path.exists(ckpt_path + '.index') and os.path.exists(ckpt_path + '.meta') and \
+ os.path.exists(ckpt_path + '.data-00000-of-00001'):
+ is_valid = True
+
+ return is_valid, ckpt
+
+ def scan_psnr_folder():
+ tested_ckpt_in_psnr_sets = set()
+ for test_psnr in os.listdir(psnr_dir):
+ tested_ckpt_in_psnr_sets.add(test_psnr)
+ return tested_ckpt_in_psnr_sets
+
+ def scan_model_folder():
+ saved_models = set()
+ for ckpt_name in os.listdir(snapshot_dir):
+ is_valid, ckpt = check_ckpt_valid(ckpt_name)
+ if is_valid:
+ saved_models.add(ckpt)
+ return saved_models
+
+ tested_ckpt_sets = scan_psnr_folder()
+ while True:
+ all_model_ckpts = scan_model_folder()
+ new_model_ckpts = all_model_ckpts - tested_ckpt_sets
+
+ for ckpt_name in new_model_ckpts:
+ # inference
+ ckpt = os.path.join(snapshot_dir, ckpt_name)
+ inference_func(ckpt, dataset_name, evaluate_name)
+
+ tested_ckpt_sets.add(ckpt_name)
+
+ print('waiting for models...')
+ # evaluate.evaluate('compute_auc', psnr_dir)
+ time.sleep(60)
+ else:
+ inference_func(snapshot_dir, dataset_name, evaluate_name)