summaryrefslogtreecommitdiff
path: root/test.py
diff options
context:
space:
mode:
Diffstat (limited to 'test.py')
-rw-r--r--test.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..22d092c
--- /dev/null
+++ b/test.py
@@ -0,0 +1,34 @@
+import time
+import os
+from options.test_options import TestOptions
+opt = TestOptions().parse() # set CUDA_VISIBLE_DEVICES before import torch
+
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+from util.visualizer import Visualizer
+from pdb import set_trace as st
+from util import html
+
+opt.nThreads = 1 # test code only supports nThreads=1
+opt.batchSize = 1 #test code only supports batchSize=1
+opt.serial_batches = True # no shuffle
+
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+model = create_model(opt)
+visualizer = Visualizer(opt)
+# create website
+web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
+webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
+# test
+for i, data in enumerate(dataset):
+ if i >= opt.how_many:
+ break
+ model.set_input(data)
+ model.test()
+ visuals = model.get_current_visuals()
+ img_path = model.get_image_paths()
+ print('process image... %s' % img_path)
+ visualizer.save_images(webpage, visuals, img_path)
+
+webpage.save()