summaryrefslogtreecommitdiff
path: root/test.py
diff options
context:
space:
mode:
authorBoris Fomitchev <bfomitchev@nvidia.com>2018-05-08 00:56:35 -0700
committerBoris Fomitchev <bfomitchev@nvidia.com>2018-05-08 00:56:35 -0700
commit4ca6b1610f9fa65f8bd7d7c15059bfde18a2f02a (patch)
treeec2eeb09cdef6a70ea5612c3e6aa91ed2849414a /test.py
parent736a2dc9afef418820e9c52f4f3b38460360b9f2 (diff)
Added data size and ONNX export options, FP16 inference is working
Diffstat (limited to 'test.py')
-rwxr-xr-xtest.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/test.py b/test.py
index a9c8729..1effb08 100755
--- a/test.py
+++ b/test.py
@@ -26,6 +26,20 @@ webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.na
for i, data in enumerate(dataset):
if i >= opt.how_many:
break
+ if opt.data_type == 16:
+ model.half()
+ data['label'] = data['label'].half()
+ data['inst'] = data['inst'].half()
+ elif opt.data_type == 8:
+ model.type(torch.uint8)
+
+ if opt.export_onnx:
+ assert opt.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
+ if opt.verbose:
+ print(model)
+ generated = torch.onnx.export(model, [data['label'], data['inst']],
+ opt.export_onnx, verbose=True)
+
generated = model.inference(data['label'], data['inst'])
visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
('synthesized_image', util.tensor2im(generated.data[0]))])