diff options
| author | SsnL <tongzhou.wang.1994@gmail.com> | 2017-11-09 16:08:30 -0500 |
|---|---|---|
| committer | SsnL <tongzhou.wang.1994@gmail.com> | 2017-11-09 16:15:05 -0500 |
| commit | c2fc8d442f1248231eab4b73e111665288b1e615 (patch) | |
| tree | 9621879f1070cf1d99829fa020e87000f878a3fa /models/cycle_gan_model.py | |
| parent | a24e24d67d88f75869f447690f7d994fe7d42e2d (diff) | |
update
Diffstat (limited to 'models/cycle_gan_model.py')
| -rw-r--r-- | models/cycle_gan_model.py | 30 |
1 files changed, 14 insertions, 16 deletions
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index e840e7b..fe06823 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -91,13 +91,13 @@ class CycleGANModel(BaseModel): def test(self): real_A = Variable(self.input_A, volatile=True) - fake_B = self.netG_A.forward(real_A) - self.rec_A = self.netG_B.forward(fake_B).data + fake_B = self.netG_A(real_A) + self.rec_A = self.netG_B(fake_B).data self.fake_B = fake_B.data real_B = Variable(self.input_B, volatile=True) - fake_A = self.netG_B.forward(real_B) - self.rec_B = self.netG_A.forward(fake_A).data + fake_A = self.netG_B(real_B) + self.rec_B = self.netG_A(fake_A).data self.fake_A = fake_A.data # get image paths @@ -106,10 +106,10 @@ class CycleGANModel(BaseModel): def backward_D_basic(self, netD, real, fake): # Real - pred_real = netD.forward(real) + pred_real = netD(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake - pred_fake = netD.forward(fake.detach()) + pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss loss_D = (loss_D_real + loss_D_fake) * 0.5 @@ -134,17 +134,16 @@ class CycleGANModel(BaseModel): # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed. - idt_A = self.netG_A.forward(self.real_B) + idt_A = self.netG_A(self.real_B) loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed. - idt_B = self.netG_B.forward(self.real_A) + idt_B = self.netG_B(self.real_A) loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt self.idt_A = idt_A.data self.idt_B = idt_B.data self.loss_idt_A = loss_idt_A.data[0] self.loss_idt_B = loss_idt_B.data[0] - else: loss_idt_A = 0 loss_idt_B = 0 @@ -152,23 +151,22 @@ class CycleGANModel(BaseModel): self.loss_idt_B = 0 # GAN loss D_A(G_A(A)) - fake_B = self.netG_A.forward(self.real_A) - pred_fake = self.netD_A.forward(fake_B) + fake_B = self.netG_A(self.real_A) + pred_fake = self.netD_A(fake_B) loss_G_A = self.criterionGAN(pred_fake, True) # GAN loss D_B(G_B(B)) - fake_A = self.netG_B.forward(self.real_B) - pred_fake = self.netD_B.forward(fake_A) + fake_A = self.netG_B(self.real_B) + pred_fake = self.netD_B(fake_A) loss_G_B = self.criterionGAN(pred_fake, True) # Forward cycle loss - rec_A = self.netG_B.forward(fake_B) + rec_A = self.netG_B(fake_B) loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A # Backward cycle loss - rec_B = self.netG_A.forward(fake_A) + rec_B = self.netG_A(fake_A) loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B - # combined loss loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B loss_G.backward() |
