diff options
Diffstat (limited to 'models/models.py')
| -rw-r--r-- | models/models.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/models/models.py b/models/models.py index 7e790d0..8fea4f4 100644 --- a/models/models.py +++ b/models/models.py @@ -4,12 +4,17 @@ def create_model(opt): print(opt.model) if opt.model == 'cycle_gan': from .cycle_gan_model import CycleGANModel - assert(opt.align_data == False) + #assert(opt.align_data == False) model = CycleGANModel() - if opt.model == 'pix2pix': + elif opt.model == 'pix2pix': from .pix2pix_model import Pix2PixModel assert(opt.align_data == True) model = Pix2PixModel() + elif opt.model == 'one_direction_test': + from .one_direction_test_model import OneDirectionTestModel + model = OneDirectionTestModel() + else: + raise ValueError("Model [%s] not recognized." % opt.model) model.initialize(opt) print("model [%s] was created" % (model.name())) return model |
