summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/cycle_gan_model.py81
1 files changed, 53 insertions, 28 deletions
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 29389db..ecb92dc 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -130,33 +130,58 @@ class CycleGANModel(BaseModel):
# Identity loss
if lambda_idt > 0:
# G_A should be identity if real_B is fed.
- self.idt_A = self.netG_A.forward(self.real_B)
- self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
+ idt_A = self.netG_A.forward(self.real_B)
+ loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed.
- self.idt_B = self.netG_B.forward(self.real_A)
- self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
+ idt_B = self.netG_B.forward(self.real_A)
+ loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt
+
+ self.idt_A = idt_A.data
+ self.idt_B = idt_B.data
+ self.loss_idt_A = loss_idt_A.data[0]
+ self.loss_idt_B = loss_idt_B.data[0]
+
else:
+ loss_idt_A = 0
+ loss_idt_B = 0
self.loss_idt_A = 0
self.loss_idt_B = 0
# GAN loss
# D_A(G_A(A))
- self.fake_B = self.netG_A.forward(self.real_A)
- pred_fake = self.netD_A.forward(self.fake_B)
- self.loss_G_A = self.criterionGAN(pred_fake, True)
+ fake_B = self.netG_A.forward(self.real_A)
+ pred_fake = self.netD_A.forward(fake_B)
+ loss_G_A = self.criterionGAN(pred_fake, True)
+
# D_B(G_B(B))
- self.fake_A = self.netG_B.forward(self.real_B)
- pred_fake = self.netD_B.forward(self.fake_A)
- self.loss_G_B = self.criterionGAN(pred_fake, True)
+ fake_A = self.netG_B.forward(self.real_B)
+ pred_fake = self.netD_B.forward(fake_A)
+ loss_G_B = self.criterionGAN(pred_fake, True)
+
# Forward cycle loss
- self.rec_A = self.netG_B.forward(self.fake_B)
- self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
+ rec_A = self.netG_B.forward(fake_B)
+ loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
+
# Backward cycle loss
- self.rec_B = self.netG_A.forward(self.fake_A)
- self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
+ rec_B = self.netG_A.forward(fake_A)
+ loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
+
# combined loss
- self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
- self.loss_G.backward()
+ loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
+ loss_G.backward()
+
+ self.fake_B = fake_B.data
+ self.fake_A = fake_A.data
+ self.rec_A = rec_A.data
+ self.rec_B = rec_B.data
+
+ self.loss_G_A = loss_G_A.data[0]
+ self.loss_G_B = loss_G_B.data[0]
+ self.loss_cycle_A = loss_cycle_A.data[0]
+ self.loss_cycle_B = loss_cycle_B.data[0]
+
+
+
def optimize_parameters(self):
# forward
@@ -176,14 +201,14 @@ class CycleGANModel(BaseModel):
def get_current_errors(self):
D_A = self.loss_D_A.data[0]
- G_A = self.loss_G_A.data[0]
- Cyc_A = self.loss_cycle_A.data[0]
+ G_A = self.loss_G_A
+ Cyc_A = self.loss_cycle_A
D_B = self.loss_D_B.data[0]
- G_B = self.loss_G_B.data[0]
- Cyc_B = self.loss_cycle_B.data[0]
+ G_B = self.loss_G_B
+ Cyc_B = self.loss_cycle_B
if self.opt.identity > 0.0:
- idt_A = self.loss_idt_A.data[0]
- idt_B = self.loss_idt_B.data[0]
+ idt_A = self.loss_idt_A
+ idt_B = self.loss_idt_B
return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
else:
@@ -192,14 +217,14 @@ class CycleGANModel(BaseModel):
def get_current_visuals(self):
real_A = util.tensor2im(self.real_A.data)
- fake_B = util.tensor2im(self.fake_B.data)
- rec_A = util.tensor2im(self.rec_A.data)
+ fake_B = util.tensor2im(self.fake_B)
+ rec_A = util.tensor2im(self.rec_A)
real_B = util.tensor2im(self.real_B.data)
- fake_A = util.tensor2im(self.fake_A.data)
- rec_B = util.tensor2im(self.rec_B.data)
+ fake_A = util.tensor2im(self.fake_A)
+ rec_B = util.tensor2im(self.rec_B)
if self.opt.isTrain and self.opt.identity > 0.0:
- idt_A = util.tensor2im(self.idt_A.data)
- idt_B = util.tensor2im(self.idt_B.data)
+ idt_A = util.tensor2im(self.idt_A)
+ idt_B = util.tensor2im(self.idt_B)
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
else: