summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorBoris Fomitchev <bfomitchev@nvidia.com>2018-05-08 20:18:10 -0700
committerBoris Fomitchev <bfomitchev@nvidia.com>2018-05-08 20:18:10 -0700
commit25e205604e7eafa83867a15cfda526461fe58455 (patch)
treefcce10851fb0d1627b60cc23100659506f1462bb /models
parent4ca6b1610f9fa65f8bd7d7c15059bfde18a2f02a (diff)
ONNX export is working
Diffstat (limited to 'models')
-rwxr-xr-xmodels/models.py7
-rwxr-xr-xmodels/pix2pixHD_model.py7
2 files changed, 12 insertions, 2 deletions
diff --git a/models/models.py b/models/models.py
index 805696f..8e72e46 100755
--- a/models/models.py
+++ b/models/models.py
@@ -4,8 +4,11 @@ import torch
def create_model(opt):
if opt.model == 'pix2pixHD':
- from .pix2pixHD_model import Pix2PixHDModel
- model = Pix2PixHDModel()
+ from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
+ if opt.isTrain:
+ model = Pix2PixHDModel()
+ else:
+ model = InferenceModel()
else:
from .ui_model import UIModel
model = UIModel()
diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py
index de594ab..631a10f 100755
--- a/models/pix2pixHD_model.py
+++ b/models/pix2pixHD_model.py
@@ -270,3 +270,10 @@ class Pix2PixHDModel(BaseModel):
if self.opt.verbose:
print('update learning rate: %f -> %f' % (self.old_lr, lr))
self.old_lr = lr
+
+class InferenceModel(Pix2PixHDModel):
+ def forward(self, inp):
+ label, inst = inp
+ return self.inference(label, inst)
+
+