summaryrefslogtreecommitdiff
path: root/models/models.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/models.py')
-rw-r--r--models/models.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/models/models.py b/models/models.py
index 7e790d0..8fea4f4 100644
--- a/models/models.py
+++ b/models/models.py
@@ -4,12 +4,17 @@ def create_model(opt):
print(opt.model)
if opt.model == 'cycle_gan':
from .cycle_gan_model import CycleGANModel
- assert(opt.align_data == False)
+ #assert(opt.align_data == False)
model = CycleGANModel()
- if opt.model == 'pix2pix':
+ elif opt.model == 'pix2pix':
from .pix2pix_model import Pix2PixModel
assert(opt.align_data == True)
model = Pix2PixModel()
+ elif opt.model == 'one_direction_test':
+ from .one_direction_test_model import OneDirectionTestModel
+ model = OneDirectionTestModel()
+ else:
+ raise ValueError("Model [%s] not recognized." % opt.model)
model.initialize(opt)
print("model [%s] was created" % (model.name()))
return model