diff options
| author | Jun-Yan Zhu <junyanz@users.noreply.github.com> | 2017-11-02 15:29:06 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-11-02 15:29:06 -0700 |
| commit | c9f1962642c63772cd182f93e81908c90ba267f9 (patch) | |
| tree | 0d017c6d30c6c173d7d6b54f0283c364dc27451d /models/cycle_gan_model.py | |
| parent | 9d1bc76e6a4f791a25db1179c7c2b4c62a8d55cd (diff) | |
| parent | d6cb5036a2a8b57f8c1a7cdc7e7cc80416d4ee78 (diff) | |
Merge pull request #140 from jpmerc/patch-1
gpu memory leaks
Diffstat (limited to 'models/cycle_gan_model.py')
| -rw-r--r-- | models/cycle_gan_model.py | 81 |
1 file changed, 53 insertions, 28 deletions
```diff
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 29389db..ecb92dc 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -130,33 +130,58 @@ class CycleGANModel(BaseModel):
         # Identity loss
         if lambda_idt > 0:
             # G_A should be identity if real_B is fed.
-            self.idt_A = self.netG_A.forward(self.real_B)
-            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
+            idt_A = self.netG_A.forward(self.real_B)
+            loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt
             # G_B should be identity if real_A is fed.
-            self.idt_B = self.netG_B.forward(self.real_A)
-            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
+            idt_B = self.netG_B.forward(self.real_A)
+            loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt
+
+            self.idt_A = idt_A.data
+            self.idt_B = idt_B.data
+            self.loss_idt_A = loss_idt_A.data[0]
+            self.loss_idt_B = loss_idt_B.data[0]
+
         else:
+            loss_idt_A = 0
+            loss_idt_B = 0
             self.loss_idt_A = 0
             self.loss_idt_B = 0
 
         # GAN loss
         # D_A(G_A(A))
-        self.fake_B = self.netG_A.forward(self.real_A)
-        pred_fake = self.netD_A.forward(self.fake_B)
-        self.loss_G_A = self.criterionGAN(pred_fake, True)
+        fake_B = self.netG_A.forward(self.real_A)
+        pred_fake = self.netD_A.forward(fake_B)
+        loss_G_A = self.criterionGAN(pred_fake, True)
+
         # D_B(G_B(B))
-        self.fake_A = self.netG_B.forward(self.real_B)
-        pred_fake = self.netD_B.forward(self.fake_A)
-        self.loss_G_B = self.criterionGAN(pred_fake, True)
+        fake_A = self.netG_B.forward(self.real_B)
+        pred_fake = self.netD_B.forward(fake_A)
+        loss_G_B = self.criterionGAN(pred_fake, True)
+
         # Forward cycle loss
-        self.rec_A = self.netG_B.forward(self.fake_B)
-        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
+        rec_A = self.netG_B.forward(fake_B)
+        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
+
         # Backward cycle loss
-        self.rec_B = self.netG_A.forward(self.fake_A)
-        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
+        rec_B = self.netG_A.forward(fake_A)
+        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
+
         # combined loss
-        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
-        self.loss_G.backward()
+        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
+        loss_G.backward()
+
+        self.fake_B = fake_B.data
+        self.fake_A = fake_A.data
+        self.rec_A = rec_A.data
+        self.rec_B = rec_B.data
+
+        self.loss_G_A = loss_G_A.data[0]
+        self.loss_G_B = loss_G_B.data[0]
+        self.loss_cycle_A = loss_cycle_A.data[0]
+        self.loss_cycle_B = loss_cycle_B.data[0]
+
+
+
 
     def optimize_parameters(self):
         # forward
@@ -176,14 +201,14 @@ class CycleGANModel(BaseModel):
 
     def get_current_errors(self):
         D_A = self.loss_D_A.data[0]
-        G_A = self.loss_G_A.data[0]
-        Cyc_A = self.loss_cycle_A.data[0]
+        G_A = self.loss_G_A
+        Cyc_A = self.loss_cycle_A
         D_B = self.loss_D_B.data[0]
-        G_B = self.loss_G_B.data[0]
-        Cyc_B = self.loss_cycle_B.data[0]
+        G_B = self.loss_G_B
+        Cyc_B = self.loss_cycle_B
         if self.opt.identity > 0.0:
-            idt_A = self.loss_idt_A.data[0]
-            idt_B = self.loss_idt_B.data[0]
+            idt_A = self.loss_idt_A
+            idt_B = self.loss_idt_B
             return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
                                 ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
         else:
@@ -192,14 +217,14 @@ class CycleGANModel(BaseModel):
 
     def get_current_visuals(self):
         real_A = util.tensor2im(self.real_A.data)
-        fake_B = util.tensor2im(self.fake_B.data)
-        rec_A = util.tensor2im(self.rec_A.data)
+        fake_B = util.tensor2im(self.fake_B)
+        rec_A = util.tensor2im(self.rec_A)
         real_B = util.tensor2im(self.real_B.data)
-        fake_A = util.tensor2im(self.fake_A.data)
-        rec_B = util.tensor2im(self.rec_B.data)
+        fake_A = util.tensor2im(self.fake_A)
+        rec_B = util.tensor2im(self.rec_B)
         if self.opt.isTrain and self.opt.identity > 0.0:
-            idt_A = util.tensor2im(self.idt_A.data)
-            idt_B = util.tensor2im(self.idt_B.data)
+            idt_A = util.tensor2im(self.idt_A)
+            idt_B = util.tensor2im(self.idt_B)
             return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
                                 ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
         else:
```
