From 91e5f1ffb152e1b729fe9d530d9f01e73017abbf Mon Sep 17 00:00:00 2001 From: Jules Laplace Date: Sat, 5 May 2018 19:07:10 +0200 Subject: inference test --- Codes/inference-test.py | 155 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 Codes/inference-test.py (limited to 'Codes/inference-test.py') diff --git a/Codes/inference-test.py b/Codes/inference-test.py new file mode 100644 index 0000000..b16a56c --- /dev/null +++ b/Codes/inference-test.py @@ -0,0 +1,155 @@ +import tensorflow as tf +import os +import time +import numpy as np +import pickle + + +from models import generator +from utils import DataLoader, load, save, psnr_error +from constant import const +import evaluate + + +slim = tf.contrib.slim + +os.environ['CUDA_DEVICES_ORDER'] = "PCI_BUS_ID" +os.environ['CUDA_VISIBLE_DEVICES'] = const.GPU + +dataset_name = const.DATASET +test_folder = const.TEST_FOLDER + +num_his = const.NUM_HIS +height, width = 256, 256 + +snapshot_dir = const.SNAPSHOT_DIR +psnr_dir = const.PSNR_DIR +evaluate_name = const.EVALUATE + +print(const) + + +# define dataset +with tf.name_scope('dataset'): + test_video_clips_tensor = tf.placeholder(shape=[1, height, width, 3 * (num_his + 1)], + dtype=tf.float32) + test_inputs = test_video_clips_tensor[..., 0:num_his*3] + test_gt = test_video_clips_tensor[..., -3:] + print('test inputs = {}'.format(test_inputs)) + print('test prediction gt = {}'.format(test_gt)) + +# define testing generator function and +# in testing, only generator networks, there is no discriminator networks and flownet. +with tf.variable_scope('generator', reuse=None): + print('testing = {}'.format(tf.get_variable_scope().name)) + test_outputs = generator(test_inputs, layers=4, output_channel=3) + test_psnr_error = psnr_error(gen_frames=test_outputs, gt_frames=test_gt) + + +config = tf.ConfigProto() +config.gpu_options.allow_growth = True +with tf.Session(config=config) as sess: + # dataset + data_loader = DataLoader(test_folder, height, width) + + # initialize weights + sess.run(tf.global_variables_initializer()) + print('Init global successfully!') + + # tf saver + saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=None) + + restore_var = [v for v in tf.global_variables()] + loader = tf.train.Saver(var_list=restore_var) + + def inference_func(ckpt, dataset_name, evaluate_name): + load(loader, sess, ckpt) + + output_records = [] + videos_info = data_loader.videos + num_videos = len(videos_info.keys()) + total = 0 + timestamp = time.time() + + for video_name, video in videos_info.items(): + length = video['length'] + total += length + psnrs = np.empty(shape=(length,), dtype=np.float32) + + for i in range(num_his, length): + video_clip = data_loader.get_video_clips(video_name, i - num_his, i + 1) + output, psnr = sess.run([test_outputs, test_psnr_error] + feed_dict={test_video_clips_tensor: video_clip[np.newaxis, ...]}) + outputs[i] = output + + tf.image.encode_png( + output, + compression=-1, + name=None + ) + + print('video = {} / {}, i = {} / {}, psnr = {:.6f}'.format( + video_name, num_videos, i, length, psnr)) + + outputs[0:num_his] = outputs[num_his] + output_records.append(outputs) + + result_dict = {'dataset': dataset_name, 'output': output_records, 'flow': [], 'names': [], 'diff_mask': []} + + used_time = time.time() - timestamp + print('total time = {}, fps = {}'.format(used_time, total / used_time)) + + # TODO specify what's the actual name of ckpt. + pickle_path = os.path.join(output_dir, os.path.split(ckpt)[-1]) + with open(pickle_path, 'wb') as writer: + pickle.dump(result_dict, writer, pickle.HIGHEST_PROTOCOL) + + # results = evaluate.evaluate(evaluate_name, pickle_path) + # print(results) + + + if os.path.isdir(snapshot_dir): + def check_ckpt_valid(ckpt_name): + is_valid = False + ckpt = '' + if ckpt_name.startswith('model.ckpt-'): + ckpt_name_splits = ckpt_name.split('.') + ckpt = str(ckpt_name_splits[0]) + '.' + str(ckpt_name_splits[1]) + ckpt_path = os.path.join(snapshot_dir, ckpt) + if os.path.exists(ckpt_path + '.index') and os.path.exists(ckpt_path + '.meta') and \ + os.path.exists(ckpt_path + '.data-00000-of-00001'): + is_valid = True + + return is_valid, ckpt + + def scan_psnr_folder(): + tested_ckpt_in_psnr_sets = set() + for test_psnr in os.listdir(psnr_dir): + tested_ckpt_in_psnr_sets.add(test_psnr) + return tested_ckpt_in_psnr_sets + + def scan_model_folder(): + saved_models = set() + for ckpt_name in os.listdir(snapshot_dir): + is_valid, ckpt = check_ckpt_valid(ckpt_name) + if is_valid: + saved_models.add(ckpt) + return saved_models + + tested_ckpt_sets = scan_psnr_folder() + while True: + all_model_ckpts = scan_model_folder() + new_model_ckpts = all_model_ckpts - tested_ckpt_sets + + for ckpt_name in new_model_ckpts: + # inference + ckpt = os.path.join(snapshot_dir, ckpt_name) + inference_func(ckpt, dataset_name, evaluate_name) + + tested_ckpt_sets.add(ckpt_name) + + print('waiting for models...') + # evaluate.evaluate('compute_auc', psnr_dir) + time.sleep(60) + else: + inference_func(snapshot_dir, dataset_name, evaluate_name) -- cgit v1.2.3-70-g09d2