diff options
| author | jules@lens <julescarbon@gmail.com> | 2018-05-05 19:32:25 +0200 |
|---|---|---|
| committer | jules@lens <julescarbon@gmail.com> | 2018-05-05 19:32:25 +0200 |
| commit | 93df56d88da5e2390e0d542ed37a57d48f3f3105 (patch) | |
| tree | 9144319dbe9fb2f70f03fea0c68cf043db38c585 /Codes/inference-test.py | |
| parent | 91e5f1ffb152e1b729fe9d530d9f01e73017abbf (diff) | |
Diffstat (limited to 'Codes/inference-test.py')
| -rw-r--r-- | Codes/inference-test.py | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/Codes/inference-test.py b/Codes/inference-test.py index b16a56c..83a4fd4 100644 --- a/Codes/inference-test.py +++ b/Codes/inference-test.py @@ -10,6 +10,7 @@ from utils import DataLoader, load, save, psnr_error from constant import const import evaluate +from PIL import Image slim = tf.contrib.slim @@ -66,6 +67,7 @@ with tf.Session(config=config) as sess: load(loader, sess, ckpt) output_records = [] + outputs = [] videos_info = data_loader.videos num_videos = len(videos_info.keys()) total = 0 @@ -78,15 +80,19 @@ with tf.Session(config=config) as sess: for i in range(num_his, length): video_clip = data_loader.get_video_clips(video_name, i - num_his, i + 1) - output, psnr = sess.run([test_outputs, test_psnr_error] + output, psnr = sess.run([test_outputs, test_psnr_error], feed_dict={test_video_clips_tensor: video_clip[np.newaxis, ...]}) - outputs[i] = output + outputs.append(output) - tf.image.encode_png( - output, - compression=-1, - name=None - ) + #tf.image.encode_png( + # output, + # compression=-1, + # name=None + #) + ohhh = (output[0,...] + 1.0) * 127.5 + print(ohhh.shape, np.amin(ohhh), np.amax(ohhh)) + out = Image.fromarray(ohhh, 'RGB') + out.save(os.path.join(os.getcwd(), '{}_{:05d}.png'.format(video_name, i))) print('video = {} / {}, i = {} / {}, psnr = {:.6f}'.format( video_name, num_videos, i, length, psnr)) @@ -108,6 +114,7 @@ with tf.Session(config=config) as sess: # print(results) + print(os.path.isdir(snapshot_dir)) if os.path.isdir(snapshot_dir): def check_ckpt_valid(ckpt_name): is_valid = False |
