summaryrefslogtreecommitdiff
path: root/models/models.py
diff options
context:
space:
mode:
authorTing-Chun Wang <tcwang0509@berkeley.edu>2018-05-30 22:39:01 -0700
committerGitHub <noreply@github.com>2018-05-30 22:39:01 -0700
commita2340c3fff9de44c8ef1fea5b90fced756fbbb18 (patch)
tree39f3c05c80a94d721ec6fed0f0da65ecbc3bc603 /models/models.py
parent1b89cd010dce2e6edaa07d23c8edd8dfe146e0e1 (diff)
parent25e205604e7eafa83867a15cfda526461fe58455 (diff)
Merge pull request #33 from borisfom/fp16
Added data size and ONNX export options, FP16 inference is working
Diffstat (limited to 'models/models.py')
-rwxr-xr-xmodels/models.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/models/models.py b/models/models.py
index 0ba442f..8e72e46 100755
--- a/models/models.py
+++ b/models/models.py
@@ -4,13 +4,17 @@ import torch
def create_model(opt):
if opt.model == 'pix2pixHD':
- from .pix2pixHD_model import Pix2PixHDModel
- model = Pix2PixHDModel()
+ from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
+ if opt.isTrain:
+ model = Pix2PixHDModel()
+ else:
+ model = InferenceModel()
else:
from .ui_model import UIModel
model = UIModel()
model.initialize(opt)
- print("model [%s] was created" % (model.name()))
+ if opt.verbose:
+ print("model [%s] was created" % (model.name()))
if opt.isTrain and len(opt.gpu_ids):
model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)