diff options
| author | SsnL <tongzhou.wang.1994@gmail.com> | 2017-11-09 16:08:30 -0500 |
|---|---|---|
| committer | SsnL <tongzhou.wang.1994@gmail.com> | 2017-11-09 16:15:05 -0500 |
| commit | c2fc8d442f1248231eab4b73e111665288b1e615 (patch) | |
| tree | 9621879f1070cf1d99829fa020e87000f878a3fa /models/pix2pix_model.py | |
| parent | a24e24d67d88f75869f447690f7d994fe7d42e2d (diff) | |
update
Diffstat (limited to 'models/pix2pix_model.py')
| -rw-r--r-- | models/pix2pix_model.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 388a8d3..56adfc1 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -70,13 +70,13 @@ class Pix2PixModel(BaseModel): def forward(self): self.real_A = Variable(self.input_A) - self.fake_B = self.netG.forward(self.real_A) + self.fake_B = self.netG(self.real_A) self.real_B = Variable(self.input_B) # no backprop gradients def test(self): self.real_A = Variable(self.input_A, volatile=True) - self.fake_B = self.netG.forward(self.real_A) + self.fake_B = self.netG(self.real_A) self.real_B = Variable(self.input_B, volatile=True) # get image paths @@ -87,12 +87,12 @@ class Pix2PixModel(BaseModel): # Fake # stop backprop to the generator by detaching fake_B fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data) - pred_fake = self.netD.forward(fake_AB.detach()) + pred_fake = self.netD(fake_AB.detach()) self.loss_D_fake = self.criterionGAN(pred_fake, False) # Real real_AB = torch.cat((self.real_A, self.real_B), 1) - pred_real = self.netD.forward(real_AB) + pred_real = self.netD(real_AB) self.loss_D_real = self.criterionGAN(pred_real, True) # Combined loss @@ -103,7 +103,7 @@ class Pix2PixModel(BaseModel): def backward_G(self): # First, G(A) should fake the discriminator fake_AB = torch.cat((self.real_A, self.fake_B), 1) - pred_fake = self.netD.forward(fake_AB) + pred_fake = self.netD(fake_AB) self.loss_G_GAN = self.criterionGAN(pred_fake, True) # Second, G(A) = B |
