diff options
Diffstat (limited to 'models')
| -rw-r--r-- | models/cycle_gan_model.py | 24 | ||||
| -rw-r--r-- | models/networks.py | 31 |
2 files changed, 50 insertions, 5 deletions
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index 29389db..74771cf 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -44,9 +44,9 @@ class CycleGANModel(BaseModel): which_epoch = opt.which_epoch self.load_network(self.netG_A, 'G_A', which_epoch) self.load_network(self.netG_B, 'G_B', which_epoch) - if self.isTrain: - self.load_network(self.netD_A, 'D_A', which_epoch) - self.load_network(self.netD_B, 'D_B', which_epoch) + #if self.isTrain: + # self.load_network(self.netD_A, 'D_A', which_epoch) + # self.load_network(self.netD_B, 'D_B', which_epoch) if self.isTrain: self.old_lr = opt.lr @@ -77,6 +77,8 @@ class CycleGANModel(BaseModel): networks.print_network(self.netD_B) print('-----------------------------------------------') + self.step_count = 0 + def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] @@ -84,6 +86,7 @@ class CycleGANModel(BaseModel): self.input_A.resize_(input_A.size()).copy_(input_A) self.input_B.resize_(input_B.size()).copy_(input_B) self.image_paths = input['A_paths' if AtoB else 'B_paths'] + self.image_paths2 = input['B_paths' if AtoB else 'A_paths'] def forward(self): self.real_A = Variable(self.input_A) @@ -138,7 +141,7 @@ class CycleGANModel(BaseModel): else: self.loss_idt_A = 0 self.loss_idt_B = 0 - + # GAN loss # D_A(G_A(A)) self.fake_B = self.netG_A.forward(self.real_A) @@ -148,6 +151,7 @@ class CycleGANModel(BaseModel): self.fake_A = self.netG_B.forward(self.real_B) pred_fake = self.netD_B.forward(self.fake_A) self.loss_G_B = self.criterionGAN(pred_fake, True) + # Forward cycle loss self.rec_A = self.netG_B.forward(self.fake_B) self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A @@ -155,15 +159,25 @@ class CycleGANModel(BaseModel): self.rec_B = self.netG_A.forward(self.fake_A) self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B # combined loss - self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B + self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B self.loss_G.backward() def optimize_parameters(self): + self.step_count += 1 # forward self.forward() # G_A and G_B self.optimizer_G.zero_grad() self.backward_G() + if (self.loss_G != self.loss_G).sum().data[0] > 0: + exit(1) + #for w in self.netG_A.parameters(): + #print(w.grad.data) + # if (w.grad.data != w.grad.data).sum() > 0: + # print(w.grad.data) + # exit(1) + #print(self.image_paths, self.image_paths2) + #return self.optimizer_G.step() # D_A self.optimizer_D_A.zero_grad() diff --git a/models/networks.py b/models/networks.py index 2df58fe..965bacb 100644 --- a/models/networks.py +++ b/models/networks.py @@ -136,6 +136,8 @@ def define_D(input_nc, ndf, which_model_netD, netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) elif which_model_netD == 'n_layers': netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) + elif which_model_netD == 'pixel': + netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) else: raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD) @@ -432,3 +434,32 @@ class NLayerDiscriminator(nn.Module): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) + +class PixelDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]): + super(PixelDiscriminator, self).__init__() + self.gpu_ids = gpu_ids + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.net = [ + nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), + norm_layer(ndf * 2), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] + + if use_sigmoid: + self.net.append(nn.Sigmoid()) + + self.net = nn.Sequential(*self.net) + + def forward(self, input): + if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): + return nn.parallel.data_parallel(self.net, input, self.gpu_ids) + else: + return self.net(input) + |
