diff options
| author | Taesung Park <taesung_park@berkeley.edu> | 2017-11-08 11:37:42 -0800 |
|---|---|---|
| committer | Taesung Park <taesung_park@berkeley.edu> | 2017-11-08 11:37:42 -0800 |
| commit | b6f5966eb8224dfc7be68b1b67a87f006e42730d (patch) | |
| tree | c34a3f4746f1453fd83aeeeb2135eb7b6f0afb63 /models/cycle_gan_model.py | |
| parent | 5e0f7d6980ed1a1aaac8593351028d320e5f0a94 (diff) | |
working version with handwritten GAN loss. Shift value can be changed
Diffstat (limited to 'models/cycle_gan_model.py')
| -rw-r--r-- | models/cycle_gan_model.py | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index 29389db..74771cf 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -44,9 +44,9 @@ class CycleGANModel(BaseModel): which_epoch = opt.which_epoch self.load_network(self.netG_A, 'G_A', which_epoch) self.load_network(self.netG_B, 'G_B', which_epoch) - if self.isTrain: - self.load_network(self.netD_A, 'D_A', which_epoch) - self.load_network(self.netD_B, 'D_B', which_epoch) + #if self.isTrain: + # self.load_network(self.netD_A, 'D_A', which_epoch) + # self.load_network(self.netD_B, 'D_B', which_epoch) if self.isTrain: self.old_lr = opt.lr @@ -77,6 +77,8 @@ class CycleGANModel(BaseModel): networks.print_network(self.netD_B) print('-----------------------------------------------') + self.step_count = 0 + def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] @@ -84,6 +86,7 @@ class CycleGANModel(BaseModel): self.input_A.resize_(input_A.size()).copy_(input_A) self.input_B.resize_(input_B.size()).copy_(input_B) self.image_paths = input['A_paths' if AtoB else 'B_paths'] + self.image_paths2 = input['B_paths' if AtoB else 'A_paths'] def forward(self): self.real_A = Variable(self.input_A) @@ -138,7 +141,7 @@ class CycleGANModel(BaseModel): else: self.loss_idt_A = 0 self.loss_idt_B = 0 - + # GAN loss # D_A(G_A(A)) self.fake_B = self.netG_A.forward(self.real_A) @@ -148,6 +151,7 @@ class CycleGANModel(BaseModel): self.fake_A = self.netG_B.forward(self.real_B) pred_fake = self.netD_B.forward(self.fake_A) self.loss_G_B = self.criterionGAN(pred_fake, True) + # Forward cycle loss self.rec_A = self.netG_B.forward(self.fake_B) self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A @@ -155,15 +159,25 @@ class CycleGANModel(BaseModel): self.rec_B = self.netG_A.forward(self.fake_A) self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B # combined loss - self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B + self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B self.loss_G.backward() def optimize_parameters(self): + self.step_count += 1 # forward self.forward() # G_A and G_B self.optimizer_G.zero_grad() self.backward_G() + if (self.loss_G != self.loss_G).sum().data[0] > 0: + exit(1) + #for w in self.netG_A.parameters(): + #print(w.grad.data) + # if (w.grad.data != w.grad.data).sum() > 0: + # print(w.grad.data) + # exit(1) + #print(self.image_paths, self.image_paths2) + #return self.optimizer_G.step() # D_A self.optimizer_D_A.zero_grad() |
