summaryrefslogtreecommitdiff
path: root/models/models.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/models.py')
-rw-r--r--models/models.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/models/models.py b/models/models.py
index 8fea4f4..efcd898 100644
--- a/models/models.py
+++ b/models/models.py
@@ -3,15 +3,16 @@ def create_model(opt):
model = None
print(opt.model)
if opt.model == 'cycle_gan':
+ assert(opt.dataset_mode == 'unaligned')
from .cycle_gan_model import CycleGANModel
- #assert(opt.align_data == False)
model = CycleGANModel()
elif opt.model == 'pix2pix':
+ assert(opt.dataset_mode == 'aligned')
from .pix2pix_model import Pix2PixModel
- assert(opt.align_data == True)
model = Pix2PixModel()
- elif opt.model == 'one_direction_test':
- from .one_direction_test_model import OneDirectionTestModel
+ elif opt.model == 'test':
+ assert(opt.dataset_mode == 'single')
+ from .test_model import TestModel
model = OneDirectionTestModel()
else:
raise ValueError("Model [%s] not recognized." % opt.model)