diff options
| author | Boris Fomitchev <bfomitchev@nvidia.com> | 2018-05-08 20:18:10 -0700 |
|---|---|---|
| committer | Boris Fomitchev <bfomitchev@nvidia.com> | 2018-05-08 20:18:10 -0700 |
| commit | 25e205604e7eafa83867a15cfda526461fe58455 (patch) | |
| tree | fcce10851fb0d1627b60cc23100659506f1462bb /models | |
| parent | 4ca6b1610f9fa65f8bd7d7c15059bfde18a2f02a (diff) | |
ONNX export is working
Diffstat (limited to 'models')
| -rwxr-xr-x | models/models.py | 7 | ||||
| -rwxr-xr-x | models/pix2pixHD_model.py | 7 |
2 files changed, 12 insertions, 2 deletions
diff --git a/models/models.py b/models/models.py index 805696f..8e72e46 100755 --- a/models/models.py +++ b/models/models.py @@ -4,8 +4,11 @@ import torch def create_model(opt): if opt.model == 'pix2pixHD': - from .pix2pixHD_model import Pix2PixHDModel - model = Pix2PixHDModel() + from .pix2pixHD_model import Pix2PixHDModel, InferenceModel + if opt.isTrain: + model = Pix2PixHDModel() + else: + model = InferenceModel() else: from .ui_model import UIModel model = UIModel() diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py index de594ab..631a10f 100755 --- a/models/pix2pixHD_model.py +++ b/models/pix2pixHD_model.py @@ -270,3 +270,10 @@ class Pix2PixHDModel(BaseModel): if self.opt.verbose: print('update learning rate: %f -> %f' % (self.old_lr, lr)) self.old_lr = lr + +class InferenceModel(Pix2PixHDModel): + def forward(self, inp): + label, inst = inp + return self.inference(label, inst) + + |
