summaryrefslogtreecommitdiff
path: root/test.py
diff options
context:
space:
mode:
authortingchunw <tingchunw@nvidia.com>2017-12-04 16:52:46 -0800
committertingchunw <tingchunw@nvidia.com>2017-12-04 16:52:46 -0800
commit9054cf9b0c327a5077fd0793abe178f400da3315 (patch)
tree3c69c07bdcba86c47d8442648fd69c0434e04136 /test.py
parentf9e9999541d67a908a169cc88407675133130e1f (diff)
first commit
Diffstat (limited to 'test.py')
-rwxr-xr-xtest.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/test.py b/test.py
new file mode 100755
index 0000000..d96ac10
--- /dev/null
+++ b/test.py
@@ -0,0 +1,37 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import time
+import os
+from collections import OrderedDict
+from options.test_options import TestOptions
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+import util.util as util
+from util.visualizer import Visualizer
+from util import html
+
+opt = TestOptions().parse(save=False)
+opt.nThreads = 1 # test code only supports nThreads = 1
+opt.batchSize = 1 # test code only supports batchSize = 1
+opt.serial_batches = True # no shuffle
+opt.no_flip = True # no flip
+
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+model = create_model(opt)
+visualizer = Visualizer(opt)
+# create website
+web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
+webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
+# test
+for i, data in enumerate(dataset):
+ if i >= opt.how_many:
+ break
+ generated = model.inference(data['label'], data['inst'])
+ visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
+ ('synthesized_image', util.tensor2im(generated.data[0]))])
+ img_path = data['path']
+ print('process image... %s' % img_path)
+ visualizer.save_images(webpage, visuals, img_path)
+
+webpage.save()