summary | refs | log | tree | commit | diff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--  models/cycle_gan_model.py  24
-rw-r--r--  models/networks.py         31
2 files changed, 50 insertions(+), 5 deletions(-)
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 29389db..74771cf 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -44,9 +44,9 @@ class CycleGANModel(BaseModel):
which_epoch = opt.which_epoch
self.load_network(self.netG_A, 'G_A', which_epoch)
self.load_network(self.netG_B, 'G_B', which_epoch)
- if self.isTrain:
- self.load_network(self.netD_A, 'D_A', which_epoch)
- self.load_network(self.netD_B, 'D_B', which_epoch)
+ #if self.isTrain:
+ # self.load_network(self.netD_A, 'D_A', which_epoch)
+ # self.load_network(self.netD_B, 'D_B', which_epoch)
if self.isTrain:
self.old_lr = opt.lr
@@ -77,6 +77,8 @@ class CycleGANModel(BaseModel):
networks.print_network(self.netD_B)
print('-----------------------------------------------')
+ self.step_count = 0
+
def set_input(self, input):
AtoB = self.opt.which_direction == 'AtoB'
input_A = input['A' if AtoB else 'B']
@@ -84,6 +86,7 @@ class CycleGANModel(BaseModel):
self.input_A.resize_(input_A.size()).copy_(input_A)
self.input_B.resize_(input_B.size()).copy_(input_B)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
+ self.image_paths2 = input['B_paths' if AtoB else 'A_paths']
def forward(self):
self.real_A = Variable(self.input_A)
@@ -138,7 +141,7 @@ class CycleGANModel(BaseModel):
else:
self.loss_idt_A = 0
self.loss_idt_B = 0
-
+
# GAN loss
# D_A(G_A(A))
self.fake_B = self.netG_A.forward(self.real_A)
@@ -148,6 +151,7 @@ class CycleGANModel(BaseModel):
self.fake_A = self.netG_B.forward(self.real_B)
pred_fake = self.netD_B.forward(self.fake_A)
self.loss_G_B = self.criterionGAN(pred_fake, True)
+
# Forward cycle loss
self.rec_A = self.netG_B.forward(self.fake_B)
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
@@ -155,15 +159,25 @@ class CycleGANModel(BaseModel):
self.rec_B = self.netG_A.forward(self.fake_A)
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
# combined loss
- self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
+ self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
self.loss_G.backward()
def optimize_parameters(self):
+ self.step_count += 1
# forward
self.forward()
# G_A and G_B
self.optimizer_G.zero_grad()
self.backward_G()
+ if (self.loss_G != self.loss_G).sum().data[0] > 0:
+ exit(1)
+ #for w in self.netG_A.parameters():
+ #print(w.grad.data)
+ # if (w.grad.data != w.grad.data).sum() > 0:
+ # print(w.grad.data)
+ # exit(1)
+ #print(self.image_paths, self.image_paths2)
+ #return
self.optimizer_G.step()
# D_A
self.optimizer_D_A.zero_grad()
diff --git a/models/networks.py b/models/networks.py
index 2df58fe..965bacb 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -136,6 +136,8 @@ def define_D(input_nc, ndf, which_model_netD,
netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
elif which_model_netD == 'n_layers':
netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
+ elif which_model_netD == 'pixel':
+ netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
else:
raise NotImplementedError('Discriminator model name [%s] is not recognized' %
which_model_netD)
@@ -432,3 +434,32 @@ class NLayerDiscriminator(nn.Module):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
+
+class PixelDiscriminator(nn.Module):
+ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
+ super(PixelDiscriminator, self).__init__()
+ self.gpu_ids = gpu_ids
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+
+ self.net = [
+ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
+ norm_layer(ndf * 2),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
+
+ if use_sigmoid:
+ self.net.append(nn.Sigmoid())
+
+ self.net = nn.Sequential(*self.net)
+
+ def forward(self, input):
+ if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
+ return nn.parallel.data_parallel(self.net, input, self.gpu_ids)
+ else:
+ return self.net(input)
+