summaryrefslogtreecommitdiff
path: root/models/models.py
blob: 7e790d0ec40026dd957be971275faff4385a6598 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def create_model(opt):
    model = None
    print(opt.model)
    if opt.model == 'cycle_gan':
        from .cycle_gan_model import CycleGANModel
        assert(opt.align_data == False)
        model = CycleGANModel()
    if opt.model == 'pix2pix':
        from .pix2pix_model import Pix2PixModel
        assert(opt.align_data == True)
        model = Pix2PixModel()
    model.initialize(opt)
    print("model [%s] was created" % (model.name()))
    return model