summaryrefslogtreecommitdiff
path: root/models/models.py
diff options
context:
space:
mode:
authortingchunw <tingchunw@nvidia.com>2018-01-30 01:30:18 +0000
committertingchunw <tingchunw@nvidia.com>2018-01-30 01:30:18 +0000
commitedf910b1c1d02020b31782ab4c3b6ebf9af8c323 (patch)
treecdf1431710eaac3e13f4fc5aa899101dea94c736 /models/models.py
parentdd05da797863a13f6e45ec0a2d2ff3c7f8142f38 (diff)
change dataset naming convention and add ui model
Diffstat (limited to 'models/models.py')
-rwxr-xr-xmodels/models.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/models/models.py b/models/models.py
index 351483c..0ba442f 100755
--- a/models/models.py
+++ b/models/models.py
@@ -3,8 +3,12 @@
import torch
def create_model(opt):
- from .pix2pixHD_model import Pix2PixHDModel
- model = Pix2PixHDModel()
+ if opt.model == 'pix2pixHD':
+ from .pix2pixHD_model import Pix2PixHDModel
+ model = Pix2PixHDModel()
+ else:
+ from .ui_model import UIModel
+ model = UIModel()
model.initialize(opt)
print("model [%s] was created" % (model.name()))