summaryrefslogtreecommitdiff
path: root/models/cycle_gan_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/cycle_gan_model.py')
-rw-r--r--models/cycle_gan_model.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 730d077..85432bb 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -121,7 +121,7 @@ class CycleGANModel(BaseModel):
self.loss_D_B = loss_D_B.data[0]
def backward_G(self):
- lambda_idt = self.opt.identity
+ lambda_idt = self.opt.lambda_identity
lambda_A = self.opt.lambda_A
lambda_B = self.opt.lambda_B
# Identity loss
@@ -193,7 +193,7 @@ class CycleGANModel(BaseModel):
def get_current_errors(self):
ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
- if self.opt.identity > 0.0:
+ if self.opt.lambda_identity > 0.0:
ret_errors['idt_A'] = self.loss_idt_A
ret_errors['idt_B'] = self.loss_idt_B
return ret_errors
@@ -207,7 +207,7 @@ class CycleGANModel(BaseModel):
rec_B = util.tensor2im(self.rec_B)
ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
- if self.opt.isTrain and self.opt.identity > 0.0:
+ if self.opt.isTrain and self.opt.lambda_identity > 0.0:
ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
return ret_visuals