diff options
| author | Ting-Chun Wang <tcwang0509@berkeley.edu> | 2018-05-30 22:39:01 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2018-05-30 22:39:01 -0700 |
| commit | a2340c3fff9de44c8ef1fea5b90fced756fbbb18 (patch) | |
| tree | 39f3c05c80a94d721ec6fed0f0da65ecbc3bc603 /test.py | |
| parent | 1b89cd010dce2e6edaa07d23c8edd8dfe146e0e1 (diff) | |
| parent | 25e205604e7eafa83867a15cfda526461fe58455 (diff) | |
Merge pull request #33 from borisfom/fp16
Added data size and ONNX export options, FP16 inference is working
Diffstat (limited to 'test.py')
| -rwxr-xr-x | test.py | 36 |
1 files changed, 34 insertions, 2 deletions
```diff
@@ -8,6 +8,8 @@ from models.models import create_model
 import util.util as util
 from util.visualizer import Visualizer
 from util import html
+import torch
+from run_engine import run_trt_engine, run_onnx
 
 opt = TestOptions().parse(save=False)
 opt.nThreads = 1   # test code only supports nThreads = 1
@@ -17,16 +19,46 @@ opt.no_flip = True  # no flip
 
 data_loader = CreateDataLoader(opt)
 dataset = data_loader.load_data()
-model = create_model(opt)
 visualizer = Visualizer(opt)
 # create website
 web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
 webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
 # test
+
+if not opt.engine and not opt.onnx:
+    model = create_model(opt)
+    if opt.data_type == 16:
+        model.half()
+    elif opt.data_type == 8:
+        model.type(torch.uint8)
+
+    if opt.verbose:
+        print(model)
+
+
 for i, data in enumerate(dataset):
     if i >= opt.how_many:
         break
-    generated = model.inference(data['label'], data['inst'])
+    if opt.data_type == 16:
+        data['label'] = data['label'].half()
+        data['inst'] = data['inst'].half()
+    elif opt.data_type == 8:
+        data['label'] = data['label'].uint8()
+        data['inst'] = data['inst'].uint8()
+    if opt.export_onnx:
+        print ("Exporting to ONNX: ", opt.export_onnx)
+        assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
+        torch.onnx.export(model, [data['label'], data['inst']],
+                          opt.export_onnx, verbose=True)
+        exit(0)
+    minibatch = 1
+    if opt.engine:
+        generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
+    elif opt.onnx:
+        generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
+    else:
+        generated = model.inference(data['label'], data['inst'])
+
     visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                            ('synthesized_image', util.tensor2im(generated.data[0]))])
     img_path = data['path']
```
