summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test.py60
1 files changed, 42 insertions, 18 deletions
diff --git a/test.py b/test.py
index aaa2ff5..0b6aa7e 100644
--- a/test.py
+++ b/test.py
@@ -6,15 +6,18 @@ import glob
import scipy
import argparse
import os
+import subprocess
from PIL import Image
from utils import LoadImage, DownSample, AVG_PSNR, depth_to_space_3D, DynFilter3D, LoadParams
from nets import FR_16L, FR_28L, FR_52L
parser = argparse.ArgumentParser()
-parser.add_argument('L', metavar='L', type=int, help='Network depth: One of 16, 28, 52')
-parser.add_argument('T', metavar='T', help='Input type: L(Low-resolution) or G(Ground-truth)')
-parser.add_argument('dir', metavar='dir', default=None, help='Directory to process')
+parser.add_argument('--depth', metavar='L', type=int, default=28, help='Network depth: One of 16, 28, 52')
+parser.add_argument('--type', metavar='T', default='L', help='Input type: L(Low-resolution) or G(Ground-truth)')
+parser.add_argument('--in_dir', metavar='in_dir', default=None, help='Directory to process')
+parser.add_argument('--out_dir', metavar='out_dir', default='/media/blue/uprez', default=None, help='Directory to process')
+parser.add_argument('--no_mov', action='store_true')
args = parser.parse_args()
# Size of input temporal radius
@@ -36,12 +39,17 @@ if not(args.T == 'L' or args.T =='G'):
print('Invalid input type: {} (Must be L(Low-resolution) or G(Ground-truth))'.format(args.T))
exit(1)
-def process_dir(v, scene_name=None):
- if scene_name is None:
- scene_name = v.split('/')[-1]
- os.makedirs('./results/{}L/L/{}/'.format(args.L, scene_name))
+def process_dir(dir):
+ dir_partz = dir.split('/')
+ dataset = dir_partz[-2]
+ part = dir_partz[-1]
+ tag = '_'.join([dataset, args.L + 'L', part])
+ out_path = os.path.join(args.out_dir, 'results', dataset, args.L + 'L', part)
+ render_path = os.path.join(args.out_dir, 'renders')
+ os.makedirs(out_path)
+ os.makedirs(render_path)
- dir_frames = sorted(glob.glob(v + '/*.png'))
+ dir_frames = sorted(glob.glob(os.path.join(dir + '*.png')))
print(dir_frames)
frames = []
@@ -58,7 +66,21 @@ def process_dir(v, scene_name=None):
out_H = sess.run(GH, feed_dict={L: in_L, is_train: False})
out_H = np.clip(out_H, 0, 1)
- Image.fromarray(np.around(out_H[0,0]*255).astype(np.uint8)).save('./results/{}L/L/{}/frame_{:05d}.png'.format(args.L, scene_name, i+1))
+ Image.fromarray(np.around(out_H[0,0]*255).astype(np.uint8)).save('{}/frame_{:05d}.png'.format(out_path, i+1))
+
+ if not args.no_mov:
+ subprocess.call([
+ 'ffmpeg',
+ '-hide_banner',
+ '-i', os.path.join(out_path, 'frame_%05d.png'),
+ '-y',
+ '-c:v', 'libx264',
+ '-preset', 'slow',
+ '-crf', '19',
+ '-vf', 'fps=25',
+ '-pix_fmt', 'yuv420p',
+ os.path.join(render_path, tag + '.mp4')
+ ])
def G(x, is_train):
# shape of x: [B,T_in,H,W,C]
@@ -115,8 +137,17 @@ with tf.Session(config=config) as sess:
# Load parameters
LoadParams(sess, [params_G], in_file='params_{}L_x{}.h5'.format(args.L, R))
-
- if args.T == 'G':
+
+ if args.T == 'L':
+ # Test using Low-resolution videos
+ if args.dir:
+ for dir in sorted(glob.glob(os.path.join(args.in_dir, '*'))):
+ process_dir(dir)
+ else:
+ for dir in sorted(glob.glob('./inputs/L/*')):
+ process_dir(dir)
+
+ elif args.T == 'G':
# Test using GT videos
avg_psnrs = []
dir_inputs = glob.glob('./inputs/G/*')
@@ -152,10 +183,3 @@ with tf.Session(config=config) as sess:
avg_psnrs.append(avg_psnr)
print('Scene {}: PSNR {}'.format(scene_name, avg_psnr))
- elif args.T == 'L':
- # Test using Low-resolution videos
- if args.dir:
- process_dir(args.dir, scene_name=args.dir.replace('.', '').replace('/', '_'))
- else:
- for v in sorted(glob.glob('./inputs/L/*')):
- process_dir(v)