diff options
Diffstat (limited to 'models/pix2pix_model.py')
| -rw-r--r-- | models/pix2pix_model.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 18ba53f..56adfc1 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -70,13 +70,13 @@ class Pix2PixModel(BaseModel): def forward(self): self.real_A = Variable(self.input_A) - self.fake_B = self.netG.forward(self.real_A) + self.fake_B = self.netG(self.real_A) self.real_B = Variable(self.input_B) # no backprop gradients def test(self): self.real_A = Variable(self.input_A, volatile=True) - self.fake_B = self.netG.forward(self.real_A) + self.fake_B = self.netG(self.real_A) self.real_B = Variable(self.input_B, volatile=True) # get image paths @@ -86,14 +86,14 @@ class Pix2PixModel(BaseModel): def backward_D(self): # Fake # stop backprop to the generator by detaching fake_B - fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) - self.pred_fake = self.netD.forward(fake_AB.detach()) - self.loss_D_fake = self.criterionGAN(self.pred_fake, False) + fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data) + pred_fake = self.netD(fake_AB.detach()) + self.loss_D_fake = self.criterionGAN(pred_fake, False) # Real real_AB = torch.cat((self.real_A, self.real_B), 1) - self.pred_real = self.netD.forward(real_AB) - self.loss_D_real = self.criterionGAN(self.pred_real, True) + pred_real = self.netD(real_AB) + self.loss_D_real = self.criterionGAN(pred_real, True) # Combined loss self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 @@ -103,7 +103,7 @@ class Pix2PixModel(BaseModel): def backward_G(self): # First, G(A) should fake the discriminator fake_AB = torch.cat((self.real_A, self.fake_B), 1) - pred_fake = self.netD.forward(fake_AB) + pred_fake = self.netD(fake_AB) self.loss_G_GAN = self.criterionGAN(pred_fake, True) # Second, G(A) = B |
