summaryrefslogtreecommitdiff
path: root/models/pix2pix_model.py
diff options
context:
space:
mode:
authorjunyanz <junyanz@berkeley.edu>2017-11-04 02:47:39 -0700
committerjunyanz <junyanz@berkeley.edu>2017-11-04 02:47:39 -0700
commit7a9021d4f131ee059d49ff9b2d135e6543f75763 (patch)
treeb2ec50907d242f7bbaeeafb2ff381ceb0d2d071d /models/pix2pix_model.py
parent6b8e96c4bbd73a1e1d4e126d795a26fd0dae983c (diff)
fix small issues
Diffstat (limited to 'models/pix2pix_model.py')
-rw-r--r--models/pix2pix_model.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 8cd494f..388a8d3 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -86,14 +86,14 @@ class Pix2PixModel(BaseModel):
def backward_D(self):
# Fake
# stop backprop to the generator by detaching fake_B
- fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
+ fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data)
pred_fake = self.netD.forward(fake_AB.detach())
self.loss_D_fake = self.criterionGAN(pred_fake, False)
# Real
real_AB = torch.cat((self.real_A, self.real_B), 1)
pred_real = self.netD.forward(real_AB)
- self.loss_D_real = self.criterionGAN(self.pred_real, True)
+ self.loss_D_real = self.criterionGAN(pred_real, True)
# Combined loss
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5