Diffstat (limited to 'models')
-rw-r--r--   models/base_model.py       |  3
-rw-r--r--   models/cycle_gan_model.py  | 88
-rw-r--r--   models/networks.py         |  4
-rw-r--r--   models/pix2pix_model.py    |  6
4 files changed, 46 insertions, 55 deletions
diff --git a/models/base_model.py b/models/base_model.py
index d62d189..646a014 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -44,13 +44,14 @@ class BaseModel():
         save_path = os.path.join(self.save_dir, save_filename)
         torch.save(network.cpu().state_dict(), save_path)
         if len(gpu_ids) and torch.cuda.is_available():
-            network.cuda(device_id=gpu_ids[0]) # network.cuda(device=gpu_ids[0]) for the latest version.
+            network.cuda(device_id=gpu_ids[0]) # network.cuda(device=gpu_ids[0]) for the latest version.

     # helper loading function that can be used by subclasses
     def load_network(self, network, network_label, epoch_label):
         save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
         save_path = os.path.join(self.save_dir, save_filename)
         network.load_state_dict(torch.load(save_path))
+
     # update learning rate (called once every epoch)
     def update_learning_rate(self):
         for scheduler in self.schedulers:
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index ff7330b..71a447d 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -90,13 +90,15 @@ class CycleGANModel(BaseModel):
         self.real_B = Variable(self.input_B)

     def test(self):
-        self.real_A = Variable(self.input_A, volatile=True)
-        self.fake_B = self.netG_A.forward(self.real_A)
-        self.rec_A = self.netG_B.forward(self.fake_B)
+        real_A = Variable(self.input_A, volatile=True)
+        fake_B = self.netG_A.forward(real_A)
+        self.rec_A = self.netG_B.forward(fake_B).data
+        self.fake_B = fake_B.data

-        self.real_B = Variable(self.input_B, volatile=True)
-        self.fake_A = self.netG_B.forward(self.real_B)
-        self.rec_B = self.netG_A.forward(self.fake_A)
+        real_B = Variable(self.input_B, volatile=True)
+        fake_A = self.netG_B.forward(real_B)
+        self.rec_B = self.netG_A.forward(fake_A).data
+        self.fake_A = fake_A.data

     # get image paths
     def get_image_paths(self):
@@ -117,11 +119,13 @@ class CycleGANModel(BaseModel):

     def backward_D_A(self):
         fake_B = self.fake_B_pool.query(self.fake_B)
-        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+        self.loss_D_A = loss_D_A.data[0]

     def backward_D_B(self):
         fake_A = self.fake_A_pool.query(self.fake_A)
-        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+        self.loss_D_B = loss_D_B.data[0]

     def backward_G(self):
         lambda_idt = self.opt.identity
@@ -135,53 +139,49 @@ class CycleGANModel(BaseModel):
             # G_B should be identity if real_A is fed.
             idt_B = self.netG_B.forward(self.real_A)
             loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt
-
+
             self.idt_A = idt_A.data
             self.idt_B = idt_B.data
             self.loss_idt_A = loss_idt_A.data[0]
-            self.loss_idt_B = loss_idt_B.data[0]
-
+            self.loss_idt_B = loss_idt_B.data[0]
+
         else:
             loss_idt_A = 0
             loss_idt_B = 0
             self.loss_idt_A = 0
             self.loss_idt_B = 0

-        # GAN loss
-        # D_A(G_A(A))
+        # GAN loss D_A(G_A(A))
         fake_B = self.netG_A.forward(self.real_A)
         pred_fake = self.netD_A.forward(fake_B)
         loss_G_A = self.criterionGAN(pred_fake, True)
-
-        # D_B(G_B(B))
+
+        # GAN loss D_B(G_B(B))
         fake_A = self.netG_B.forward(self.real_B)
         pred_fake = self.netD_B.forward(fake_A)
         loss_G_B = self.criterionGAN(pred_fake, True)
-
+
         # Forward cycle loss
         rec_A = self.netG_B.forward(fake_B)
         loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
-
+
         # Backward cycle loss
         rec_B = self.netG_A.forward(fake_A)
         loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
-
+
         # combined loss
         loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
         loss_G.backward()
-
+
         self.fake_B = fake_B.data
         self.fake_A = fake_A.data
         self.rec_A = rec_A.data
         self.rec_B = rec_B.data
-
+
         self.loss_G_A = loss_G_A.data[0]
         self.loss_G_B = loss_G_B.data[0]
         self.loss_cycle_A = loss_cycle_A.data[0]
         self.loss_cycle_B = loss_cycle_B.data[0]
-
-
-
     def optimize_parameters(self):
         # forward
@@ -200,36 +200,26 @@ class CycleGANModel(BaseModel):
         self.optimizer_D_B.step()

     def get_current_errors(self):
-        D_A = self.loss_D_A.data[0]
-        G_A = self.loss_G_A
-        Cyc_A = self.loss_cycle_A
-        D_B = self.loss_D_B.data[0]
-        G_B = self.loss_G_B
-        Cyc_B = self.loss_cycle_B
+        ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
+                                  ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
         if self.opt.identity > 0.0:
-            idt_A = self.loss_idt_A
-            idt_B = self.loss_idt_B
-            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
-                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
-        else:
-            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
-                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])
+            ret_errors['idt_A'] = self.loss_idt_A
+            ret_errors['idt_B'] = self.loss_idt_B
+        return ret_errors

     def get_current_visuals(self):
-        real_A = util.tensor2im(self.real_A.data)
-        fake_B = util.tensor2im(self.fake_B.data)
-        rec_A = util.tensor2im(self.rec_A.data)
-        real_B = util.tensor2im(self.real_B.data)
-        fake_A = util.tensor2im(self.fake_A.data)
-        rec_B = util.tensor2im(self.rec_B.data)
+        real_A = util.tensor2im(self.input_A)
+        fake_B = util.tensor2im(self.fake_B)
+        rec_A = util.tensor2im(self.rec_A)
+        real_B = util.tensor2im(self.input_B)
+        fake_A = util.tensor2im(self.fake_A)
+        rec_B = util.tensor2im(self.rec_B)
+        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
+                                   ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
         if self.opt.isTrain and self.opt.identity > 0.0:
-            idt_A = util.tensor2im(self.idt_A)
-            idt_B = util.tensor2im(self.idt_B)
-            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
-                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
-        else:
-            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
-                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
+            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
+            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
+        return ret_visuals

     def save(self, label):
         self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
diff --git a/models/networks.py b/models/networks.py
index 949659d..d071ac4 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -118,7 +118,7 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo
     else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    if len(gpu_ids) > 0:
-        netG.cuda(device_id=gpu_ids[0]) # or netG.cuda(device=gpu_ids[0]) for latest version.
+        netG.cuda(device_id=gpu_ids[0]) # or netG.cuda(device=gpu_ids[0]) for latest version.
    init_weights(netG, init_type=init_type)
    return netG
@@ -139,7 +139,7 @@ def define_D(input_nc, ndf, which_model_netD,
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD)
    if use_gpu:
-        netD.cuda(device_id=gpu_ids[0]) # or netD.cuda(device=gpu_ids[0]) for latest version.
+        netD.cuda(device_id=gpu_ids[0]) # or netD.cuda(device=gpu_ids[0]) for latest version.
    init_weights(netD, init_type=init_type)
    return netD
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 18ba53f..8cd494f 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -87,12 +87,12 @@ class Pix2PixModel(BaseModel):
         # Fake
         # stop backprop to the generator by detaching fake_B
         fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
-        self.pred_fake = self.netD.forward(fake_AB.detach())
-        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)
+        pred_fake = self.netD.forward(fake_AB.detach())
+        self.loss_D_fake = self.criterionGAN(pred_fake, False)

         # Real
         real_AB = torch.cat((self.real_A, self.real_B), 1)
-        self.pred_real = self.netD.forward(real_AB)
+        pred_real = self.netD.forward(real_AB)
         self.loss_D_real = self.criterionGAN(self.pred_real, True)

         # Combined loss
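
Note: the test() rewrite in cycle_gan_model.py keeps only .data tensors on the model, so the autograd graphs built by the generator passes are not retained between test iterations. A minimal sketch of the same idea in current PyTorch, where torch.no_grad() replaces the deprecated volatile=True flag (the helper name and argument list below are illustrative, not the repository's API):

import torch

def cycle_gan_test_step(netG_A, netG_B, input_A, input_B):
    # Inside no_grad() no graph is built, so nothing has to be stripped
    # with .data afterwards -- the modern equivalent of volatile=True.
    with torch.no_grad():
        fake_B = netG_A(input_A)   # A -> B
        rec_A = netG_B(fake_B)     # B -> A cycle reconstruction
        fake_A = netG_B(input_B)   # B -> A
        rec_B = netG_A(fake_A)     # A -> B cycle reconstruction
    # The returned tensors carry no grad history, mirroring the .data
    # copies the patched test() stores on self.
    return fake_B, rec_A, fake_A, rec_B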

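The pix2pix hunk keeps the discriminator predictions in locals and, as its comment says, detaches the fake pair so the discriminator loss cannot backpropagate into the generator. A rough, self-contained sketch of that update pattern (assumptions: BCEWithLogitsLoss stands in for the repository's GANLoss wrapper, and the function name is illustrative):

import torch
import torch.nn as nn

def pix2pix_backward_D(netD, optimizer_D, real_A, real_B, fake_B):
    criterion = nn.BCEWithLogitsLoss()

    # Fake pair: detach so gradients stop at D and never reach the
    # generator that produced fake_B.
    fake_AB = torch.cat((real_A, fake_B), 1).detach()
    pred_fake = netD(fake_AB)
    loss_D_fake = criterion(pred_fake, torch.zeros_like(pred_fake))

    # Real pair.
    real_AB = torch.cat((real_A, real_B), 1)
    pred_real = netD(real_AB)
    loss_D_real = criterion(pred_real, torch.ones_like(pred_real))

    loss_D = 0.5 * (loss_D_fake + loss_D_real)
    optimizer_D.zero_grad()
    loss_D.backward()
    optimizer_D.step()
    # .item() is the current equivalent of the .data[0] scalar logging
    # used throughout the patch above.
    return loss_D.item()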