summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore2
-rw-r--r--test.py15
2 files changed, 10 insertions, 7 deletions
diff --git a/.gitignore b/.gitignore
index 731c151..7ace067 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,5 @@ results
*.mp4
*.png
+thirds
+
diff --git a/test.py b/test.py
index 0b6aa7e..c668529 100644
--- a/test.py
+++ b/test.py
@@ -13,10 +13,11 @@ from utils import LoadImage, DownSample, AVG_PSNR, depth_to_space_3D, DynFilter3
from nets import FR_16L, FR_28L, FR_52L
parser = argparse.ArgumentParser()
-parser.add_argument('--depth', metavar='L', type=int, default=28, help='Network depth: One of 16, 28, 52')
-parser.add_argument('--type', metavar='T', default='L', help='Input type: L(Low-resolution) or G(Ground-truth)')
+parser.add_argument('--L', metavar='L', type=int, default=28, help='Network depth: One of 16, 28, 52')
+parser.add_argument('--T', metavar='T', default='L', help='Input type: L(Low-resolution) or G(Ground-truth)')
parser.add_argument('--in_dir', metavar='in_dir', default=None, help='Directory to process')
-parser.add_argument('--out_dir', metavar='out_dir', default='/media/blue/uprez', default=None, help='Directory to process')
+parser.add_argument('--out_dir', metavar='out_dir', default='/media/blue/uprez', help='Directory to output to')
+parser.add_argument('--network_dir', default='.', help='Path to networks')
parser.add_argument('--no_mov', action='store_true')
args = parser.parse_args()
@@ -43,8 +44,8 @@ def process_dir(dir):
dir_partz = dir.split('/')
dataset = dir_partz[-2]
part = dir_partz[-1]
- tag = '_'.join([dataset, args.L + 'L', part])
- out_path = os.path.join(args.out_dir, 'results', dataset, args.L + 'L', part)
+ tag = '_'.join([dataset, str(args.L) + 'L', part])
+ out_path = os.path.join(args.out_dir, 'results', dataset, str(args.L) + 'L', part)
render_path = os.path.join(args.out_dir, 'renders')
os.makedirs(out_path)
os.makedirs(render_path)
@@ -136,11 +137,11 @@ with tf.Session(config=config) as sess:
tf.global_variables_initializer().run()
# Load parameters
- LoadParams(sess, [params_G], in_file='params_{}L_x{}.h5'.format(args.L, R))
+ LoadParams(sess, [params_G], in_file=os.path.join(args.network_dir, 'params_{}L_x{}.h5'.format(args.L, R)))
if args.T == 'L':
# Test using Low-resolution videos
- if args.dir:
+ if args.in_dir:
for dir in sorted(glob.glob(os.path.join(args.in_dir, '*'))):
process_dir(dir)
else: