summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorTaesung Park <taesung_park@berkeley.edu>2017-12-10 23:04:41 -0800
committerTaesung Park <taesung_park@berkeley.edu>2017-12-10 23:04:41 -0800
commitf33f098be9b25c3b62523540c9c703af1db0b1c0 (patch)
tree9b51e547067b46ad8b55ddb34b207825550df867 /models
parent3d2c534933b356dc313a620639a713cb940dc756 (diff)
parent2d96edbee5a488a7861833731a2cb71b23b55727 (diff)
merged conflicts
Diffstat (limited to 'models')
-rw-r--r--models/base_model.py3
-rw-r--r--models/cycle_gan_model.py143
-rw-r--r--models/networks.py23
-rw-r--r--models/pix2pix_model.py16
-rw-r--r--models/test_model.py2
5 files changed, 93 insertions, 94 deletions
diff --git a/models/base_model.py b/models/base_model.py
index 446a903..9b55afe 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -44,13 +44,14 @@ class BaseModel():
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.cpu().state_dict(), save_path)
if len(gpu_ids) and torch.cuda.is_available():
- network.cuda(device_id=gpu_ids[0])
+ network.cuda(gpu_ids[0])
# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
network.load_state_dict(torch.load(save_path))
+
# update learning rate (called once every epoch)
def update_learning_rate(self):
for scheduler in self.schedulers:
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 74771cf..fe06823 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -44,9 +44,9 @@ class CycleGANModel(BaseModel):
which_epoch = opt.which_epoch
self.load_network(self.netG_A, 'G_A', which_epoch)
self.load_network(self.netG_B, 'G_B', which_epoch)
- #if self.isTrain:
- # self.load_network(self.netD_A, 'D_A', which_epoch)
- # self.load_network(self.netD_B, 'D_B', which_epoch)
+ if self.isTrain:
+ self.load_network(self.netD_A, 'D_A', which_epoch)
+ self.load_network(self.netD_B, 'D_B', which_epoch)
if self.isTrain:
self.old_lr = opt.lr
@@ -77,8 +77,6 @@ class CycleGANModel(BaseModel):
networks.print_network(self.netD_B)
print('-----------------------------------------------')
- self.step_count = 0
-
def set_input(self, input):
AtoB = self.opt.which_direction == 'AtoB'
input_A = input['A' if AtoB else 'B']
@@ -86,20 +84,21 @@ class CycleGANModel(BaseModel):
self.input_A.resize_(input_A.size()).copy_(input_A)
self.input_B.resize_(input_B.size()).copy_(input_B)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
- self.image_paths2 = input['B_paths' if AtoB else 'A_paths']
def forward(self):
self.real_A = Variable(self.input_A)
self.real_B = Variable(self.input_B)
def test(self):
- self.real_A = Variable(self.input_A, volatile=True)
- self.fake_B = self.netG_A.forward(self.real_A)
- self.rec_A = self.netG_B.forward(self.fake_B)
+ real_A = Variable(self.input_A, volatile=True)
+ fake_B = self.netG_A(real_A)
+ self.rec_A = self.netG_B(fake_B).data
+ self.fake_B = fake_B.data
- self.real_B = Variable(self.input_B, volatile=True)
- self.fake_A = self.netG_B.forward(self.real_B)
- self.rec_B = self.netG_A.forward(self.fake_A)
+ real_B = Variable(self.input_B, volatile=True)
+ fake_A = self.netG_B(real_B)
+ self.rec_B = self.netG_A(fake_A).data
+ self.fake_A = fake_A.data
# get image paths
def get_image_paths(self):
@@ -107,10 +106,10 @@ class CycleGANModel(BaseModel):
def backward_D_basic(self, netD, real, fake):
# Real
- pred_real = netD.forward(real)
+ pred_real = netD(real)
loss_D_real = self.criterionGAN(pred_real, True)
# Fake
- pred_fake = netD.forward(fake.detach())
+ pred_fake = netD(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss
loss_D = (loss_D_real + loss_D_fake) * 0.5
@@ -120,11 +119,13 @@ class CycleGANModel(BaseModel):
def backward_D_A(self):
fake_B = self.fake_B_pool.query(self.fake_B)
- self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+ loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+ self.loss_D_A = loss_D_A.data[0]
def backward_D_B(self):
fake_A = self.fake_A_pool.query(self.fake_A)
- self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+ loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+ self.loss_D_B = loss_D_B.data[0]
def backward_G(self):
lambda_idt = self.opt.identity
@@ -133,51 +134,59 @@ class CycleGANModel(BaseModel):
# Identity loss
if lambda_idt > 0:
# G_A should be identity if real_B is fed.
- self.idt_A = self.netG_A.forward(self.real_B)
- self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
+ idt_A = self.netG_A(self.real_B)
+ loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed.
- self.idt_B = self.netG_B.forward(self.real_A)
- self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
+ idt_B = self.netG_B(self.real_A)
+ loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt
+
+ self.idt_A = idt_A.data
+ self.idt_B = idt_B.data
+ self.loss_idt_A = loss_idt_A.data[0]
+ self.loss_idt_B = loss_idt_B.data[0]
else:
+ loss_idt_A = 0
+ loss_idt_B = 0
self.loss_idt_A = 0
self.loss_idt_B = 0
-
- # GAN loss
- # D_A(G_A(A))
- self.fake_B = self.netG_A.forward(self.real_A)
- pred_fake = self.netD_A.forward(self.fake_B)
- self.loss_G_A = self.criterionGAN(pred_fake, True)
- # D_B(G_B(B))
- self.fake_A = self.netG_B.forward(self.real_B)
- pred_fake = self.netD_B.forward(self.fake_A)
- self.loss_G_B = self.criterionGAN(pred_fake, True)
+
+ # GAN loss D_A(G_A(A))
+ fake_B = self.netG_A(self.real_A)
+ pred_fake = self.netD_A(fake_B)
+ loss_G_A = self.criterionGAN(pred_fake, True)
+
+ # GAN loss D_B(G_B(B))
+ fake_A = self.netG_B(self.real_B)
+ pred_fake = self.netD_B(fake_A)
+ loss_G_B = self.criterionGAN(pred_fake, True)
# Forward cycle loss
- self.rec_A = self.netG_B.forward(self.fake_B)
- self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
+ rec_A = self.netG_B(fake_B)
+ loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
+
# Backward cycle loss
- self.rec_B = self.netG_A.forward(self.fake_A)
- self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
+ rec_B = self.netG_A(fake_A)
+ loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
# combined loss
- self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
- self.loss_G.backward()
+ loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
+ loss_G.backward()
+
+ self.fake_B = fake_B.data
+ self.fake_A = fake_A.data
+ self.rec_A = rec_A.data
+ self.rec_B = rec_B.data
+
+ self.loss_G_A = loss_G_A.data[0]
+ self.loss_G_B = loss_G_B.data[0]
+ self.loss_cycle_A = loss_cycle_A.data[0]
+ self.loss_cycle_B = loss_cycle_B.data[0]
def optimize_parameters(self):
- self.step_count += 1
# forward
self.forward()
# G_A and G_B
self.optimizer_G.zero_grad()
self.backward_G()
- if (self.loss_G != self.loss_G).sum().data[0] > 0:
- exit(1)
- #for w in self.netG_A.parameters():
- #print(w.grad.data)
- # if (w.grad.data != w.grad.data).sum() > 0:
- # print(w.grad.data)
- # exit(1)
- #print(self.image_paths, self.image_paths2)
- #return
self.optimizer_G.step()
# D_A
self.optimizer_D_A.zero_grad()
@@ -189,36 +198,26 @@ class CycleGANModel(BaseModel):
self.optimizer_D_B.step()
def get_current_errors(self):
- D_A = self.loss_D_A.data[0]
- G_A = self.loss_G_A.data[0]
- Cyc_A = self.loss_cycle_A.data[0]
- D_B = self.loss_D_B.data[0]
- G_B = self.loss_G_B.data[0]
- Cyc_B = self.loss_cycle_B.data[0]
+ ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
+ ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
if self.opt.identity > 0.0:
- idt_A = self.loss_idt_A.data[0]
- idt_B = self.loss_idt_B.data[0]
- return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
- ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
- else:
- return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
- ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])
+ ret_errors['idt_A'] = self.loss_idt_A
+ ret_errors['idt_B'] = self.loss_idt_B
+ return ret_errors
def get_current_visuals(self):
- real_A = util.tensor2im(self.real_A.data)
- fake_B = util.tensor2im(self.fake_B.data)
- rec_A = util.tensor2im(self.rec_A.data)
- real_B = util.tensor2im(self.real_B.data)
- fake_A = util.tensor2im(self.fake_A.data)
- rec_B = util.tensor2im(self.rec_B.data)
+ real_A = util.tensor2im(self.input_A)
+ fake_B = util.tensor2im(self.fake_B)
+ rec_A = util.tensor2im(self.rec_A)
+ real_B = util.tensor2im(self.input_B)
+ fake_A = util.tensor2im(self.fake_A)
+ rec_B = util.tensor2im(self.rec_B)
+ ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
+ ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
if self.opt.isTrain and self.opt.identity > 0.0:
- idt_A = util.tensor2im(self.idt_A.data)
- idt_B = util.tensor2im(self.idt_B.data)
- return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
- ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
- else:
- return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
- ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
+ ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
+ ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
+ return ret_visuals
def save(self, label):
self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
diff --git a/models/networks.py b/models/networks.py
index 965bacb..568f8c9 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -10,16 +10,15 @@ import numpy as np
###############################################################################
-
def weights_init_normal(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
- init.uniform(m.weight.data, 0.0, 0.02)
+ init.normal(m.weight.data, 0.0, 0.02)
elif classname.find('Linear') != -1:
- init.uniform(m.weight.data, 0.0, 0.02)
+ init.normal(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
- init.uniform(m.weight.data, 1.0, 0.02)
+ init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
@@ -27,11 +26,11 @@ def weights_init_xavier(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
- init.xavier_normal(m.weight.data, gain=1)
+ init.xavier_normal(m.weight.data, gain=0.02)
elif classname.find('Linear') != -1:
- init.xavier_normal(m.weight.data, gain=1)
+ init.xavier_normal(m.weight.data, gain=0.02)
elif classname.find('BatchNorm2d') != -1:
- init.uniform(m.weight.data, 1.0, 0.02)
+ init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
@@ -43,7 +42,7 @@ def weights_init_kaiming(m):
elif classname.find('Linear') != -1:
init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
elif classname.find('BatchNorm2d') != -1:
- init.uniform(m.weight.data, 1.0, 0.02)
+ init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
@@ -55,7 +54,7 @@ def weights_init_orthogonal(m):
elif classname.find('Linear') != -1:
init.orthogonal(m.weight.data, gain=1)
elif classname.find('BatchNorm2d') != -1:
- init.uniform(m.weight.data, 1.0, 0.02)
+ init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
@@ -88,7 +87,7 @@ def get_norm_layer(norm_type='instance'):
def get_scheduler(optimizer, opt):
if opt.lr_policy == 'lambda':
def lambda_rule(epoch):
- lr_l = 1.0 - max(0, epoch - opt.niter) / float(opt.niter_decay+1)
+ lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif opt.lr_policy == 'step':
@@ -119,7 +118,7 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
if len(gpu_ids) > 0:
- netG.cuda(device_id=gpu_ids[0])
+ netG.cuda(gpu_ids[0])
init_weights(netG, init_type=init_type)
return netG
@@ -142,7 +141,7 @@ def define_D(input_nc, ndf, which_model_netD,
raise NotImplementedError('Discriminator model name [%s] is not recognized' %
which_model_netD)
if use_gpu:
- netD.cuda(device_id=gpu_ids[0])
+ netD.cuda(gpu_ids[0])
init_weights(netD, init_type=init_type)
return netD
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 18ba53f..56adfc1 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -70,13 +70,13 @@ class Pix2PixModel(BaseModel):
def forward(self):
self.real_A = Variable(self.input_A)
- self.fake_B = self.netG.forward(self.real_A)
+ self.fake_B = self.netG(self.real_A)
self.real_B = Variable(self.input_B)
# no backprop gradients
def test(self):
self.real_A = Variable(self.input_A, volatile=True)
- self.fake_B = self.netG.forward(self.real_A)
+ self.fake_B = self.netG(self.real_A)
self.real_B = Variable(self.input_B, volatile=True)
# get image paths
@@ -86,14 +86,14 @@ class Pix2PixModel(BaseModel):
def backward_D(self):
# Fake
# stop backprop to the generator by detaching fake_B
- fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
- self.pred_fake = self.netD.forward(fake_AB.detach())
- self.loss_D_fake = self.criterionGAN(self.pred_fake, False)
+ fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data)
+ pred_fake = self.netD(fake_AB.detach())
+ self.loss_D_fake = self.criterionGAN(pred_fake, False)
# Real
real_AB = torch.cat((self.real_A, self.real_B), 1)
- self.pred_real = self.netD.forward(real_AB)
- self.loss_D_real = self.criterionGAN(self.pred_real, True)
+ pred_real = self.netD(real_AB)
+ self.loss_D_real = self.criterionGAN(pred_real, True)
# Combined loss
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
@@ -103,7 +103,7 @@ class Pix2PixModel(BaseModel):
def backward_G(self):
# First, G(A) should fake the discriminator
fake_AB = torch.cat((self.real_A, self.fake_B), 1)
- pred_fake = self.netD.forward(fake_AB)
+ pred_fake = self.netD(fake_AB)
self.loss_G_GAN = self.criterionGAN(pred_fake, True)
# Second, G(A) = B
diff --git a/models/test_model.py b/models/test_model.py
index 4af1fe1..2ae2812 100644
--- a/models/test_model.py
+++ b/models/test_model.py
@@ -34,7 +34,7 @@ class TestModel(BaseModel):
def test(self):
self.real_A = Variable(self.input_A)
- self.fake_B = self.netG.forward(self.real_A)
+ self.fake_B = self.netG(self.real_A)
# get image paths
def get_image_paths(self):