diff options
| author | Taesung Park <taesung_park@berkeley.edu> | 2017-12-10 23:04:41 -0800 |
|---|---|---|
| committer | Taesung Park <taesung_park@berkeley.edu> | 2017-12-10 23:04:41 -0800 |
| commit | f33f098be9b25c3b62523540c9c703af1db0b1c0 (patch) | |
| tree | 9b51e547067b46ad8b55ddb34b207825550df867 /models | |
| parent | 3d2c534933b356dc313a620639a713cb940dc756 (diff) | |
| parent | 2d96edbee5a488a7861833731a2cb71b23b55727 (diff) | |
merged conflicts
Diffstat (limited to 'models')
| -rw-r--r-- | models/base_model.py | 3 | ||||
| -rw-r--r-- | models/cycle_gan_model.py | 143 | ||||
| -rw-r--r-- | models/networks.py | 23 | ||||
| -rw-r--r-- | models/pix2pix_model.py | 16 | ||||
| -rw-r--r-- | models/test_model.py | 2 |
5 files changed, 93 insertions, 94 deletions
diff --git a/models/base_model.py b/models/base_model.py index 446a903..9b55afe 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -44,13 +44,14 @@ class BaseModel(): save_path = os.path.join(self.save_dir, save_filename) torch.save(network.cpu().state_dict(), save_path) if len(gpu_ids) and torch.cuda.is_available(): - network.cuda(device_id=gpu_ids[0]) + network.cuda(gpu_ids[0]) # helper loading function that can be used by subclasses def load_network(self, network, network_label, epoch_label): save_filename = '%s_net_%s.pth' % (epoch_label, network_label) save_path = os.path.join(self.save_dir, save_filename) network.load_state_dict(torch.load(save_path)) + # update learning rate (called once every epoch) def update_learning_rate(self): for scheduler in self.schedulers: diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index 74771cf..fe06823 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -44,9 +44,9 @@ class CycleGANModel(BaseModel): which_epoch = opt.which_epoch self.load_network(self.netG_A, 'G_A', which_epoch) self.load_network(self.netG_B, 'G_B', which_epoch) - #if self.isTrain: - # self.load_network(self.netD_A, 'D_A', which_epoch) - # self.load_network(self.netD_B, 'D_B', which_epoch) + if self.isTrain: + self.load_network(self.netD_A, 'D_A', which_epoch) + self.load_network(self.netD_B, 'D_B', which_epoch) if self.isTrain: self.old_lr = opt.lr @@ -77,8 +77,6 @@ class CycleGANModel(BaseModel): networks.print_network(self.netD_B) print('-----------------------------------------------') - self.step_count = 0 - def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] @@ -86,20 +84,21 @@ class CycleGANModel(BaseModel): self.input_A.resize_(input_A.size()).copy_(input_A) self.input_B.resize_(input_B.size()).copy_(input_B) self.image_paths = input['A_paths' if AtoB else 'B_paths'] - self.image_paths2 = input['B_paths' if AtoB else 'A_paths'] def forward(self): self.real_A = Variable(self.input_A) self.real_B = Variable(self.input_B) def test(self): - self.real_A = Variable(self.input_A, volatile=True) - self.fake_B = self.netG_A.forward(self.real_A) - self.rec_A = self.netG_B.forward(self.fake_B) + real_A = Variable(self.input_A, volatile=True) + fake_B = self.netG_A(real_A) + self.rec_A = self.netG_B(fake_B).data + self.fake_B = fake_B.data - self.real_B = Variable(self.input_B, volatile=True) - self.fake_A = self.netG_B.forward(self.real_B) - self.rec_B = self.netG_A.forward(self.fake_A) + real_B = Variable(self.input_B, volatile=True) + fake_A = self.netG_B(real_B) + self.rec_B = self.netG_A(fake_A).data + self.fake_A = fake_A.data # get image paths def get_image_paths(self): @@ -107,10 +106,10 @@ class CycleGANModel(BaseModel): def backward_D_basic(self, netD, real, fake): # Real - pred_real = netD.forward(real) + pred_real = netD(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake - pred_fake = netD.forward(fake.detach()) + pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss loss_D = (loss_D_real + loss_D_fake) * 0.5 @@ -120,11 +119,13 @@ class CycleGANModel(BaseModel): def backward_D_A(self): fake_B = self.fake_B_pool.query(self.fake_B) - self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) + loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) + self.loss_D_A = loss_D_A.data[0] def backward_D_B(self): fake_A = self.fake_A_pool.query(self.fake_A) - self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) + loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) + self.loss_D_B = loss_D_B.data[0] def backward_G(self): lambda_idt = self.opt.identity @@ -133,51 +134,59 @@ class CycleGANModel(BaseModel): # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed. - self.idt_A = self.netG_A.forward(self.real_B) - self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt + idt_A = self.netG_A(self.real_B) + loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed. - self.idt_B = self.netG_B.forward(self.real_A) - self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt + idt_B = self.netG_B(self.real_A) + loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt + + self.idt_A = idt_A.data + self.idt_B = idt_B.data + self.loss_idt_A = loss_idt_A.data[0] + self.loss_idt_B = loss_idt_B.data[0] else: + loss_idt_A = 0 + loss_idt_B = 0 self.loss_idt_A = 0 self.loss_idt_B = 0 - - # GAN loss - # D_A(G_A(A)) - self.fake_B = self.netG_A.forward(self.real_A) - pred_fake = self.netD_A.forward(self.fake_B) - self.loss_G_A = self.criterionGAN(pred_fake, True) - # D_B(G_B(B)) - self.fake_A = self.netG_B.forward(self.real_B) - pred_fake = self.netD_B.forward(self.fake_A) - self.loss_G_B = self.criterionGAN(pred_fake, True) + + # GAN loss D_A(G_A(A)) + fake_B = self.netG_A(self.real_A) + pred_fake = self.netD_A(fake_B) + loss_G_A = self.criterionGAN(pred_fake, True) + + # GAN loss D_B(G_B(B)) + fake_A = self.netG_B(self.real_B) + pred_fake = self.netD_B(fake_A) + loss_G_B = self.criterionGAN(pred_fake, True) # Forward cycle loss - self.rec_A = self.netG_B.forward(self.fake_B) - self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A + rec_A = self.netG_B(fake_B) + loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A + # Backward cycle loss - self.rec_B = self.netG_A.forward(self.fake_A) - self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B + rec_B = self.netG_A(fake_A) + loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B # combined loss - self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B - self.loss_G.backward() + loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B + loss_G.backward() + + self.fake_B = fake_B.data + self.fake_A = fake_A.data + self.rec_A = rec_A.data + self.rec_B = rec_B.data + + self.loss_G_A = loss_G_A.data[0] + self.loss_G_B = loss_G_B.data[0] + self.loss_cycle_A = loss_cycle_A.data[0] + self.loss_cycle_B = loss_cycle_B.data[0] def optimize_parameters(self): - self.step_count += 1 # forward self.forward() # G_A and G_B self.optimizer_G.zero_grad() self.backward_G() - if (self.loss_G != self.loss_G).sum().data[0] > 0: - exit(1) - #for w in self.netG_A.parameters(): - #print(w.grad.data) - # if (w.grad.data != w.grad.data).sum() > 0: - # print(w.grad.data) - # exit(1) - #print(self.image_paths, self.image_paths2) - #return self.optimizer_G.step() # D_A self.optimizer_D_A.zero_grad() @@ -189,36 +198,26 @@ class CycleGANModel(BaseModel): self.optimizer_D_B.step() def get_current_errors(self): - D_A = self.loss_D_A.data[0] - G_A = self.loss_G_A.data[0] - Cyc_A = self.loss_cycle_A.data[0] - D_B = self.loss_D_B.data[0] - G_B = self.loss_G_B.data[0] - Cyc_B = self.loss_cycle_B.data[0] + ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A), + ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)]) if self.opt.identity > 0.0: - idt_A = self.loss_idt_A.data[0] - idt_B = self.loss_idt_B.data[0] - return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A), - ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)]) - else: - return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), - ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)]) + ret_errors['idt_A'] = self.loss_idt_A + ret_errors['idt_B'] = self.loss_idt_B + return ret_errors def get_current_visuals(self): - real_A = util.tensor2im(self.real_A.data) - fake_B = util.tensor2im(self.fake_B.data) - rec_A = util.tensor2im(self.rec_A.data) - real_B = util.tensor2im(self.real_B.data) - fake_A = util.tensor2im(self.fake_A.data) - rec_B = util.tensor2im(self.rec_B.data) + real_A = util.tensor2im(self.input_A) + fake_B = util.tensor2im(self.fake_B) + rec_A = util.tensor2im(self.rec_A) + real_B = util.tensor2im(self.input_B) + fake_A = util.tensor2im(self.fake_A) + rec_B = util.tensor2im(self.rec_B) + ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), + ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)]) if self.opt.isTrain and self.opt.identity > 0.0: - idt_A = util.tensor2im(self.idt_A.data) - idt_B = util.tensor2im(self.idt_B.data) - return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B), - ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)]) - else: - return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), - ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)]) + ret_visuals['idt_A'] = util.tensor2im(self.idt_A) + ret_visuals['idt_B'] = util.tensor2im(self.idt_B) + return ret_visuals def save(self, label): self.save_network(self.netG_A, 'G_A', label, self.gpu_ids) diff --git a/models/networks.py b/models/networks.py index 965bacb..568f8c9 100644 --- a/models/networks.py +++ b/models/networks.py @@ -10,16 +10,15 @@ import numpy as np ############################################################################### - def weights_init_normal(m): classname = m.__class__.__name__ # print(classname) if classname.find('Conv') != -1: - init.uniform(m.weight.data, 0.0, 0.02) + init.normal(m.weight.data, 0.0, 0.02) elif classname.find('Linear') != -1: - init.uniform(m.weight.data, 0.0, 0.02) + init.normal(m.weight.data, 0.0, 0.02) elif classname.find('BatchNorm2d') != -1: - init.uniform(m.weight.data, 1.0, 0.02) + init.normal(m.weight.data, 1.0, 0.02) init.constant(m.bias.data, 0.0) @@ -27,11 +26,11 @@ def weights_init_xavier(m): classname = m.__class__.__name__ # print(classname) if classname.find('Conv') != -1: - init.xavier_normal(m.weight.data, gain=1) + init.xavier_normal(m.weight.data, gain=0.02) elif classname.find('Linear') != -1: - init.xavier_normal(m.weight.data, gain=1) + init.xavier_normal(m.weight.data, gain=0.02) elif classname.find('BatchNorm2d') != -1: - init.uniform(m.weight.data, 1.0, 0.02) + init.normal(m.weight.data, 1.0, 0.02) init.constant(m.bias.data, 0.0) @@ -43,7 +42,7 @@ def weights_init_kaiming(m): elif classname.find('Linear') != -1: init.kaiming_normal(m.weight.data, a=0, mode='fan_in') elif classname.find('BatchNorm2d') != -1: - init.uniform(m.weight.data, 1.0, 0.02) + init.normal(m.weight.data, 1.0, 0.02) init.constant(m.bias.data, 0.0) @@ -55,7 +54,7 @@ def weights_init_orthogonal(m): elif classname.find('Linear') != -1: init.orthogonal(m.weight.data, gain=1) elif classname.find('BatchNorm2d') != -1: - init.uniform(m.weight.data, 1.0, 0.02) + init.normal(m.weight.data, 1.0, 0.02) init.constant(m.bias.data, 0.0) @@ -88,7 +87,7 @@ def get_norm_layer(norm_type='instance'): def get_scheduler(optimizer, opt): if opt.lr_policy == 'lambda': def lambda_rule(epoch): - lr_l = 1.0 - max(0, epoch - opt.niter) / float(opt.niter_decay+1) + lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) return lr_l scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) elif opt.lr_policy == 'step': @@ -119,7 +118,7 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo else: raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) if len(gpu_ids) > 0: - netG.cuda(device_id=gpu_ids[0]) + netG.cuda(gpu_ids[0]) init_weights(netG, init_type=init_type) return netG @@ -142,7 +141,7 @@ def define_D(input_nc, ndf, which_model_netD, raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD) if use_gpu: - netD.cuda(device_id=gpu_ids[0]) + netD.cuda(gpu_ids[0]) init_weights(netD, init_type=init_type) return netD diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 18ba53f..56adfc1 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -70,13 +70,13 @@ class Pix2PixModel(BaseModel): def forward(self): self.real_A = Variable(self.input_A) - self.fake_B = self.netG.forward(self.real_A) + self.fake_B = self.netG(self.real_A) self.real_B = Variable(self.input_B) # no backprop gradients def test(self): self.real_A = Variable(self.input_A, volatile=True) - self.fake_B = self.netG.forward(self.real_A) + self.fake_B = self.netG(self.real_A) self.real_B = Variable(self.input_B, volatile=True) # get image paths @@ -86,14 +86,14 @@ class Pix2PixModel(BaseModel): def backward_D(self): # Fake # stop backprop to the generator by detaching fake_B - fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) - self.pred_fake = self.netD.forward(fake_AB.detach()) - self.loss_D_fake = self.criterionGAN(self.pred_fake, False) + fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data) + pred_fake = self.netD(fake_AB.detach()) + self.loss_D_fake = self.criterionGAN(pred_fake, False) # Real real_AB = torch.cat((self.real_A, self.real_B), 1) - self.pred_real = self.netD.forward(real_AB) - self.loss_D_real = self.criterionGAN(self.pred_real, True) + pred_real = self.netD(real_AB) + self.loss_D_real = self.criterionGAN(pred_real, True) # Combined loss self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 @@ -103,7 +103,7 @@ class Pix2PixModel(BaseModel): def backward_G(self): # First, G(A) should fake the discriminator fake_AB = torch.cat((self.real_A, self.fake_B), 1) - pred_fake = self.netD.forward(fake_AB) + pred_fake = self.netD(fake_AB) self.loss_G_GAN = self.criterionGAN(pred_fake, True) # Second, G(A) = B diff --git a/models/test_model.py b/models/test_model.py index 4af1fe1..2ae2812 100644 --- a/models/test_model.py +++ b/models/test_model.py @@ -34,7 +34,7 @@ class TestModel(BaseModel): def test(self): self.real_A = Variable(self.input_A) - self.fake_B = self.netG.forward(self.real_A) + self.fake_B = self.netG(self.real_A) # get image paths def get_image_paths(self): |
