diff options
| author | junyanz <junyanz@mit.edu> | 2018-02-18 22:53:26 -0500 |
|---|---|---|
| committer | junyanz <junyanz@mit.edu> | 2018-02-18 22:53:26 -0500 |
| commit | 079da5c02fd99ef35d7cad0e20c2924b7c2bcffd (patch) | |
| tree | 8d00061c5b639f0ad832ab2a34a76d10a076dff7 /models | |
| parent | 51bd910bc737984163f6d6534cccbd22668e2b28 (diff) | |
fix test_model & add timer for data loader & rename identity loss
Diffstat (limited to 'models')
| -rw-r--r-- | models/cycle_gan_model.py | 6 | ||||
| -rw-r--r-- | models/test_model.py | 2 |
2 files changed, 4 insertions, 4 deletions
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index 730d077..85432bb 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -121,7 +121,7 @@ class CycleGANModel(BaseModel): self.loss_D_B = loss_D_B.data[0] def backward_G(self): - lambda_idt = self.opt.identity + lambda_idt = self.opt.lambda_identity lambda_A = self.opt.lambda_A lambda_B = self.opt.lambda_B # Identity loss @@ -193,7 +193,7 @@ class CycleGANModel(BaseModel): def get_current_errors(self): ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A), ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)]) - if self.opt.identity > 0.0: + if self.opt.lambda_identity > 0.0: ret_errors['idt_A'] = self.loss_idt_A ret_errors['idt_B'] = self.loss_idt_B return ret_errors @@ -207,7 +207,7 @@ class CycleGANModel(BaseModel): rec_B = util.tensor2im(self.rec_B) ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)]) - if self.opt.isTrain and self.opt.identity > 0.0: + if self.opt.isTrain and self.opt.lambda_identity > 0.0: ret_visuals['idt_A'] = util.tensor2im(self.idt_A) ret_visuals['idt_B'] = util.tensor2im(self.idt_B) return ret_visuals diff --git a/models/test_model.py b/models/test_model.py index 35b7402..f593c46 100644 --- a/models/test_model.py +++ b/models/test_model.py @@ -28,7 +28,7 @@ class TestModel(BaseModel): # we need to use single_dataset mode input_A = input['A'] if len(self.gpu_ids) > 0: - input_A.cuda(self.gpu_ids[0], async=True) + input_A = input_A.cuda(self.gpu_ids[0], async=True) self.input_A = input_A self.image_paths = input['A_paths'] |
