summaryrefslogtreecommitdiff
path: root/models/pix2pix_model.py
diff options
context:
space:
mode:
authorSsnL <tongzhou.wang.1994@gmail.com>2017-11-09 16:08:30 -0500
committerSsnL <tongzhou.wang.1994@gmail.com>2017-11-09 16:15:05 -0500
commitc2fc8d442f1248231eab4b73e111665288b1e615 (patch)
tree9621879f1070cf1d99829fa020e87000f878a3fa /models/pix2pix_model.py
parenta24e24d67d88f75869f447690f7d994fe7d42e2d (diff)
update
Diffstat (limited to 'models/pix2pix_model.py')
-rw-r--r--models/pix2pix_model.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 388a8d3..56adfc1 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -70,13 +70,13 @@ class Pix2PixModel(BaseModel):
def forward(self):
self.real_A = Variable(self.input_A)
- self.fake_B = self.netG.forward(self.real_A)
+ self.fake_B = self.netG(self.real_A)
self.real_B = Variable(self.input_B)
# no backprop gradients
def test(self):
self.real_A = Variable(self.input_A, volatile=True)
- self.fake_B = self.netG.forward(self.real_A)
+ self.fake_B = self.netG(self.real_A)
self.real_B = Variable(self.input_B, volatile=True)
# get image paths
@@ -87,12 +87,12 @@ class Pix2PixModel(BaseModel):
# Fake
# stop backprop to the generator by detaching fake_B
fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data)
- pred_fake = self.netD.forward(fake_AB.detach())
+ pred_fake = self.netD(fake_AB.detach())
self.loss_D_fake = self.criterionGAN(pred_fake, False)
# Real
real_AB = torch.cat((self.real_A, self.real_B), 1)
- pred_real = self.netD.forward(real_AB)
+ pred_real = self.netD(real_AB)
self.loss_D_real = self.criterionGAN(pred_real, True)
# Combined loss
@@ -103,7 +103,7 @@ class Pix2PixModel(BaseModel):
def backward_G(self):
# First, G(A) should fake the discriminator
fake_AB = torch.cat((self.real_A, self.fake_B), 1)
- pred_fake = self.netD.forward(fake_AB)
+ pred_fake = self.netD(fake_AB)
self.loss_G_GAN = self.criterionGAN(pred_fake, True)
# Second, G(A) = B