summaryrefslogtreecommitdiff
path: root/test.py
diff options
context:
space:
mode:
Diffstat (limited to 'test.py')
-rwxr-xr-xtest.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/test.py b/test.py
index a9c8729..1effb08 100755
--- a/test.py
+++ b/test.py
@@ -26,6 +26,20 @@ webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.na
for i, data in enumerate(dataset):
if i >= opt.how_many:
break
+ if opt.data_type == 16:
+ model.half()
+ data['label'] = data['label'].half()
+ data['inst'] = data['inst'].half()
+ elif opt.data_type == 8:
+ model.type(torch.uint8)
+
+ if opt.export_onnx:
+ assert opt.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
+ if opt.verbose:
+ print(model)
+ generated = torch.onnx.export(model, [data['label'], data['inst']],
+ opt.export_onnx, verbose=True)
+
generated = model.inference(data['label'], data['inst'])
visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
('synthesized_image', util.tensor2im(generated.data[0]))])