-rw-r--r--    Codes/inference-test.py    155
-rw-r--r--    Codes/test.py              196
2 files changed, 351 insertions, 0 deletions
diff --git a/Codes/inference-test.py b/Codes/inference-test.py
new file mode 100644
index 0000000..b16a56c
--- /dev/null
+++ b/Codes/inference-test.py
@@ -0,0 +1,155 @@
+import tensorflow as tf
+import os
+import time
+import numpy as np
+import pickle
+
+
+from models import generator
+from utils import DataLoader, load, save, psnr_error
+from constant import const
+import evaluate
+
+
+slim = tf.contrib.slim
+
+os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
+os.environ['CUDA_VISIBLE_DEVICES'] = const.GPU
+
+dataset_name = const.DATASET
+test_folder = const.TEST_FOLDER
+
+num_his = const.NUM_HIS
+height, width = 256, 256
+
+snapshot_dir = const.SNAPSHOT_DIR
+psnr_dir = const.PSNR_DIR
+evaluate_name = const.EVALUATE
+
+print(const)
+
+
+# define dataset
+with tf.name_scope('dataset'):
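+    # each clip stacks (num_his + 1) RGB frames along the channel axis:
+    # the first num_his * 3 channels are the input frames, the last 3 are the prediction target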
+    test_video_clips_tensor = tf.placeholder(shape=[1, height, width, 3 * (num_his + 1)],
+                                             dtype=tf.float32)
+    test_inputs = test_video_clips_tensor[..., 0:num_his*3]
+    test_gt = test_video_clips_tensor[..., -3:]
+    print('test inputs = {}'.format(test_inputs))
+    print('test prediction gt = {}'.format(test_gt))
+
+# define the testing generator function.
+# In testing, only the generator network is used; there is no discriminator network and no FlowNet.
+with tf.variable_scope('generator', reuse=None):
+    print('testing = {}'.format(tf.get_variable_scope().name))
+    test_outputs = generator(test_inputs, layers=4, output_channel=3)
+    test_psnr_error = psnr_error(gen_frames=test_outputs, gt_frames=test_gt)
+
+
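+# let TensorFlow grow GPU memory usage on demand instead of reserving it all up front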
+config = tf.ConfigProto()
+config.gpu_options.allow_growth = True
+with tf.Session(config=config) as sess:
+    # dataset
+    data_loader = DataLoader(test_folder, height, width)
+
+    # initialize weights
+    sess.run(tf.global_variables_initializer())
+    print('Init global successfully!')
+
+    # tf saver
+    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=None)
+
+    restore_var = [v for v in tf.global_variables()]
+    loader = tf.train.Saver(var_list=restore_var)
+
+    def inference_func(ckpt, dataset_name, evaluate_name):
+        load(loader, sess, ckpt)
+
+        output_records = []
+        videos_info = data_loader.videos
+        num_videos = len(videos_info.keys())
+        total = 0
+        timestamp = time.time()
+
+        for video_name, video in videos_info.items():
+            length = video['length']
+            total += length
+            psnrs = np.empty(shape=(length,), dtype=np.float32)
+            # one predicted frame per time step; the first num_his slots are filled after the loop
+            outputs = np.empty(shape=(length, height, width, 3), dtype=np.float32)
+
+            for i in range(num_his, length):
+                video_clip = data_loader.get_video_clips(video_name, i - num_his, i + 1)
+                output, psnr = sess.run([test_outputs, test_psnr_error],
+                                        feed_dict={test_video_clips_tensor: video_clip[np.newaxis, ...]})
+                psnrs[i] = psnr
+                outputs[i] = output[0]
+
+                # NOTE: the encoded PNG is only built as a graph op here; it is never
+                # evaluated or written to disk. Cast the prediction (assumed to lie in
+                # [-1, 1]) to uint8 so tf.image.encode_png accepts it.
+                encoded_png = tf.image.encode_png(
+                    tf.cast((output[0] + 1.0) * 127.5, tf.uint8),
+                    compression=-1,
+                    name=None
+                )
+
+                print('video = {} / {}, i = {} / {}, psnr = {:.6f}'.format(
+                    video_name, num_videos, i, length, psnr))
+
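+            # the first num_his frames have no prediction; pad them with the first predicted frame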
+            outputs[0:num_his] = outputs[num_his]
+            output_records.append(outputs)
+
+        result_dict = {'dataset': dataset_name, 'output': output_records, 'flow': [], 'names': [], 'diff_mask': []}
+
+        used_time = time.time() - timestamp
+        print('total time = {}, fps = {}'.format(used_time, total / used_time))
+
+        # TODO: specify the actual name of the ckpt.
+        pickle_path = os.path.join(psnr_dir, os.path.split(ckpt)[-1])
+        with open(pickle_path, 'wb') as writer:
+            pickle.dump(result_dict, writer, pickle.HIGHEST_PROTOCOL)
+
+        # results = evaluate.evaluate(evaluate_name, pickle_path)
+        # print(results)
+
+
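+    # if snapshot_dir is a directory, keep polling it for new checkpoints and run
+    # inference on each one; otherwise treat snapshot_dir as a single checkpoint path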
+    if os.path.isdir(snapshot_dir):
+        def check_ckpt_valid(ckpt_name):
+            is_valid = False
+            ckpt = ''
+            if ckpt_name.startswith('model.ckpt-'):
+                ckpt_name_splits = ckpt_name.split('.')
+                ckpt = str(ckpt_name_splits[0]) + '.' + str(ckpt_name_splits[1])
+                ckpt_path = os.path.join(snapshot_dir, ckpt)
+                if os.path.exists(ckpt_path + '.index') and os.path.exists(ckpt_path + '.meta') and \
+                        os.path.exists(ckpt_path + '.data-00000-of-00001'):
+                    is_valid = True
+
+            return is_valid, ckpt
+
+        def scan_psnr_folder():
+            tested_ckpt_in_psnr_sets = set()
+            for test_psnr in os.listdir(psnr_dir):
+                tested_ckpt_in_psnr_sets.add(test_psnr)
+            return tested_ckpt_in_psnr_sets
+
+        def scan_model_folder():
+            saved_models = set()
+            for ckpt_name in os.listdir(snapshot_dir):
+                is_valid, ckpt = check_ckpt_valid(ckpt_name)
+                if is_valid:
+                    saved_models.add(ckpt)
+            return saved_models
+
+        tested_ckpt_sets = scan_psnr_folder()
+        while True:
+            all_model_ckpts = scan_model_folder()
+            new_model_ckpts = all_model_ckpts - tested_ckpt_sets
+
+            for ckpt_name in new_model_ckpts:
+                # inference
+                ckpt = os.path.join(snapshot_dir, ckpt_name)
+                inference_func(ckpt, dataset_name, evaluate_name)
+
+                tested_ckpt_sets.add(ckpt_name)
+
+            print('waiting for models...')
+            # evaluate.evaluate('compute_auc', psnr_dir)
+            time.sleep(60)
+    else:
+        inference_func(snapshot_dir, dataset_name, evaluate_name)
diff --git a/Codes/test.py b/Codes/test.py
new file mode 100644
index 0000000..57ee298
--- /dev/null
+++ b/Codes/test.py
@@ -0,0 +1,196 @@
+import tensorflow as tf
+import os
+
+from models import generator, discriminator, flownet, initialize_flownet
+from loss_functions import intensity_loss, gradient_loss
+from utils import DataLoader, load, save, psnr_error
+from constant import const
+
+
+os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
+os.environ['CUDA_VISIBLE_DEVICES'] = const.GPU
+
+dataset_name = const.DATASET
+train_folder = const.TRAIN_FOLDER
+test_folder = const.TEST_FOLDER
+
+batch_size = const.BATCH_SIZE
+iterations = const.ITERATIONS
+num_his = const.NUM_HIS
+height, width = 256, 256
+flow_height, flow_width = const.FLOW_HEIGHT, const.FLOW_WIDTH
+
+l_num = const.L_NUM
+alpha_num = const.ALPHA_NUM
+lam_lp = const.LAM_LP
+lam_gdl = const.LAM_GDL
+lam_adv = const.LAM_ADV
+lam_flow = const.LAM_FLOW
+adversarial = (lam_adv != 0)
+
+summary_dir = const.SUMMARY_DIR
+snapshot_dir = const.SNAPSHOT_DIR
+
+
+print(const)
+
+# define dataset
+with tf.name_scope('dataset'):
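+    # the loader yields clips of num_his input frames plus one prediction target,
+    # stacked along the channel axis (3 channels per frame)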
+    test_loader = DataLoader(test_folder, resize_height=height, resize_width=width)
+    test_dataset = test_loader(batch_size=batch_size, time_steps=num_his, num_pred=1)
+    test_it = test_dataset.make_one_shot_iterator()
+    test_videos_clips_tensor = test_it.get_next()
+    test_videos_clips_tensor.set_shape([batch_size, height, width, 3*(num_his + 1)])
+
+    test_inputs = test_videos_clips_tensor[..., 0:num_his*3]
+    test_gt = test_videos_clips_tensor[..., -3:]
+
+    print('test inputs = {}'.format(test_inputs))
+    print('test prediction gt = {}'.format(test_gt))
+
+# define testing generator function
+with tf.variable_scope('generator', reuse=None):  # variables are created here; nothing to reuse in this graph
+    print('testing = {}'.format(tf.get_variable_scope().name))
+    test_outputs = generator(test_inputs, layers=4, output_channel=3)
+    test_psnr_error = psnr_error(gen_frames=test_outputs, gt_frames=test_gt)
+
+
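+# each loss term below is only built when its weight (lam_*) is non-zero;
+# otherwise a zero constant keeps the overall objective well defined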
+# define intensity loss
+if lam_lp != 0:
+    lp_loss = intensity_loss(gen_frames=test_outputs, gt_frames=test_gt, l_num=l_num)
+else:
+    lp_loss = tf.constant(0.0, dtype=tf.float32)
+
+
+# define gdl loss
+if lam_gdl != 0:
+    gdl_loss = gradient_loss(gen_frames=test_outputs, gt_frames=test_gt, alpha=alpha_num)
+else:
+    gdl_loss = tf.constant(0.0, dtype=tf.float32)
+
+
+# define flow loss
+if lam_flow != 0:
+    test_gt_flow = flownet(input_a=test_inputs[..., -3:], input_b=test_gt,
+                           height=flow_height, width=flow_width, reuse=None)
+    test_pred_flow = flownet(input_a=test_inputs[..., -3:], input_b=test_outputs,
+                             height=flow_height, width=flow_width, reuse=True)
+    flow_loss = tf.reduce_mean(tf.abs(test_gt_flow - test_pred_flow))
+else:
+    flow_loss = tf.constant(0.0, dtype=tf.float32)
+
+
+# define adversarial loss
+if adversarial:
+    with tf.variable_scope('discriminator', reuse=None):
+        real_logits, real_outputs = discriminator(inputs=test_gt)
+    with tf.variable_scope('discriminator', reuse=True):
+        fake_logits, fake_outputs = discriminator(inputs=test_outputs)
+
+    print('real_outputs = {}'.format(real_outputs))
+    print('fake_outputs = {}'.format(fake_outputs))
+
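+    # least-squares GAN objectives: the generator pushes D(fake) towards 1,
+    # the discriminator pushes D(real) towards 1 and D(fake) towards 0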
+    adv_loss = tf.reduce_mean(tf.square(fake_outputs - 1) / 2)
+    dis_loss = tf.reduce_mean(tf.square(real_outputs - 1) / 2) + tf.reduce_mean(tf.square(fake_outputs) / 2)
+else:
+    adv_loss = tf.constant(0.0, dtype=tf.float32)
+    dis_loss = tf.constant(0.0, dtype=tf.float32)
+
+
+with tf.name_scope('training'):
+    g_loss = tf.add_n([lp_loss * lam_lp, gdl_loss * lam_gdl, adv_loss * lam_adv, flow_loss * lam_flow], name='g_loss')
+
+    g_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='g_step')
+    g_lrate = tf.train.piecewise_constant(g_step, boundaries=const.LRATE_G_BOUNDARIES, values=const.LRATE_G)
+    g_optimizer = tf.train.AdamOptimizer(learning_rate=g_lrate, name='g_optimizer')
+    g_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
+
+    g_train_op = g_optimizer.minimize(g_loss, global_step=g_step, var_list=g_vars, name='g_train_op')
+
+    if adversarial:
+        # training discriminator
+        d_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='d_step')
+        d_lrate = tf.train.piecewise_constant(d_step, boundaries=const.LRATE_D_BOUNDARIES, values=const.LRATE_D)
+        d_optimizer = tf.train.AdamOptimizer(learning_rate=d_lrate, name='d_optimizer')
+        d_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
+
+        d_train_op = d_optimizer.minimize(dis_loss, global_step=d_step, var_list=d_vars, name='d_train_op')
+    else:
+        d_step = None
+        d_lrate = None
+        d_train_op = None
+
+# add all to summaries (only tensors defined in this graph are summarized)
+tf.summary.scalar(tensor=test_psnr_error, name='test_psnr_error')
+tf.summary.scalar(tensor=g_loss, name='g_loss')
+tf.summary.scalar(tensor=adv_loss, name='adv_loss')
+tf.summary.scalar(tensor=dis_loss, name='dis_loss')
+tf.summary.image(tensor=test_outputs, name='test_outputs')
+tf.summary.image(tensor=test_gt, name='test_gt')
+summary_op = tf.summary.merge_all()
+
+config = tf.ConfigProto()
+config.gpu_options.allow_growth = True
+with tf.Session(config=config) as sess:
+    # summaries
+    summary_writer = tf.summary.FileWriter(summary_dir, graph=sess.graph)
+
+    # initialize weights
+    sess.run(tf.global_variables_initializer())
+    print('Init successfully!')
+
+    if lam_flow != 0:
+        # initialize flownet
+        initialize_flownet(sess, const.FLOWNET_CHECKPOINT)
+
+    # tf saver
+    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=None)
+    restore_var = [v for v in tf.global_variables()]
+    loader = tf.train.Saver(var_list=restore_var)
+    if os.path.isdir(snapshot_dir):
+        ckpt = tf.train.get_checkpoint_state(snapshot_dir)
+        if ckpt and ckpt.model_checkpoint_path:
+            load(loader, sess, ckpt.model_checkpoint_path)
+        else:
+            print('No checkpoint file found.')
+    else:
+        load(loader, sess, snapshot_dir)
+
+    _step, _loss, _summaries = 0, None, None
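+    # each iteration runs one discriminator update (when adversarial) followed by one generator update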
+    while _step < iterations:
+        try:
+            if adversarial:
+                print('Training discriminator...')
+                _, _d_lr, _d_step, _dis_loss = sess.run([d_train_op, d_lrate, d_step, dis_loss])
+            else:
+                _d_step = 0
+                _d_lr = 0
+                _dis_loss = 0
+
+            print('Training generator...')
+            _, _g_lr, _step, _lp_loss, _gdl_loss, _adv_loss, _flow_loss, _g_loss, _test_psnr, _summaries = sess.run(
+                [g_train_op, g_lrate, g_step, lp_loss, gdl_loss, adv_loss, flow_loss, g_loss, test_psnr_error, summary_op])
+
+            if _step % 10 == 0:
+                print('DiscriminatorModel: Step {} | Global Loss: {:.6f}, lr = {:.6f}'.format(_d_step, _dis_loss, _d_lr))
+                print('GeneratorModel : Step {}, lr = {:.6f}'.format(_step, _g_lr))
+                print('                 Global Loss      : ', _g_loss)
+                print('                 intensity Loss   : ({:.4f} * {:.4f} = {:.4f})'.format(_lp_loss, lam_lp, _lp_loss * lam_lp))
+                print('                 gradient Loss    : ({:.4f} * {:.4f} = {:.4f})'.format(_gdl_loss, lam_gdl, _gdl_loss * lam_gdl))
+                print('                 adversarial Loss : ({:.4f} * {:.4f} = {:.4f})'.format(_adv_loss, lam_adv, _adv_loss * lam_adv))
+                print('                 flownet Loss     : ({:.4f} * {:.4f} = {:.4f})'.format(_flow_loss, lam_flow, _flow_loss * lam_flow))
+                print('                 PSNR Error       : ', _test_psnr)
+            if _step % 100 == 0:
+                summary_writer.add_summary(_summaries, global_step=_step)
+                print('Save summaries...')
+
+            if _step % 1000 == 0:
+                save(saver, sess, snapshot_dir, _step)
+
+        except tf.errors.OutOfRangeError:
+            print('Finish successfully!')
+            save(saver, sess, snapshot_dir, _step)
+            break