summaryrefslogtreecommitdiff
path: root/Codes/inference-test.py
diff options
context:
space:
mode:
authorjules@lens <julescarbon@gmail.com>2018-05-05 19:32:25 +0200
committerjules@lens <julescarbon@gmail.com>2018-05-05 19:32:25 +0200
commit93df56d88da5e2390e0d542ed37a57d48f3f3105 (patch)
tree9144319dbe9fb2f70f03fea0c68cf043db38c585 /Codes/inference-test.py
parent91e5f1ffb152e1b729fe9d530d9f01e73017abbf (diff)
new scriptsHEADmaster
Diffstat (limited to 'Codes/inference-test.py')
-rw-r--r--Codes/inference-test.py21
1 files changed, 14 insertions, 7 deletions
diff --git a/Codes/inference-test.py b/Codes/inference-test.py
index b16a56c..83a4fd4 100644
--- a/Codes/inference-test.py
+++ b/Codes/inference-test.py
@@ -10,6 +10,7 @@ from utils import DataLoader, load, save, psnr_error
from constant import const
import evaluate
+from PIL import Image
slim = tf.contrib.slim
@@ -66,6 +67,7 @@ with tf.Session(config=config) as sess:
load(loader, sess, ckpt)
output_records = []
+ outputs = []
videos_info = data_loader.videos
num_videos = len(videos_info.keys())
total = 0
@@ -78,15 +80,19 @@ with tf.Session(config=config) as sess:
for i in range(num_his, length):
video_clip = data_loader.get_video_clips(video_name, i - num_his, i + 1)
- output, psnr = sess.run([test_outputs, test_psnr_error]
+ output, psnr = sess.run([test_outputs, test_psnr_error],
feed_dict={test_video_clips_tensor: video_clip[np.newaxis, ...]})
- outputs[i] = output
+ outputs.append(output)
- tf.image.encode_png(
- output,
- compression=-1,
- name=None
- )
+ #tf.image.encode_png(
+ # output,
+ # compression=-1,
+ # name=None
+ #)
+ ohhh = (output[0,...] + 1.0) * 127.5
+ print(ohhh.shape, np.amin(ohhh), np.amax(ohhh))
+ out = Image.fromarray(ohhh, 'RGB')
+ out.save(os.path.join(os.getcwd(), '{}_{:05d}.png'.format(video_name, i)))
print('video = {} / {}, i = {} / {}, psnr = {:.6f}'.format(
video_name, num_videos, i, length, psnr))
@@ -108,6 +114,7 @@ with tf.Session(config=config) as sess:
# print(results)
+ print(os.path.isdir(snapshot_dir))
if os.path.isdir(snapshot_dir):
def check_ckpt_valid(ckpt_name):
is_valid = False