diff options
| author | Ting-Chun Wang <tcwang0509@berkeley.edu> | 2018-05-30 22:39:01 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2018-05-30 22:39:01 -0700 |
| commit | a2340c3fff9de44c8ef1fea5b90fced756fbbb18 (patch) | |
| tree | 39f3c05c80a94d721ec6fed0f0da65ecbc3bc603 /models/models.py | |
| parent | 1b89cd010dce2e6edaa07d23c8edd8dfe146e0e1 (diff) | |
| parent | 25e205604e7eafa83867a15cfda526461fe58455 (diff) | |
Merge pull request #33 from borisfom/fp16
Added data size and ONNX export options, FP16 inference is working
Diffstat (limited to 'models/models.py')
| -rwxr-xr-x | models/models.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/models/models.py b/models/models.py index 0ba442f..8e72e46 100755 --- a/models/models.py +++ b/models/models.py @@ -4,13 +4,17 @@ import torch def create_model(opt): if opt.model == 'pix2pixHD': - from .pix2pixHD_model import Pix2PixHDModel - model = Pix2PixHDModel() + from .pix2pixHD_model import Pix2PixHDModel, InferenceModel + if opt.isTrain: + model = Pix2PixHDModel() + else: + model = InferenceModel() else: from .ui_model import UIModel model = UIModel() model.initialize(opt) - print("model [%s] was created" % (model.name())) + if opt.verbose: + print("model [%s] was created" % (model.name())) if opt.isTrain and len(opt.gpu_ids): model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) |
