summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSsnL <tongzhou.wang.1994@gmail.com>2017-11-09 16:08:30 -0500
committerSsnL <tongzhou.wang.1994@gmail.com>2017-11-09 16:15:05 -0500
commitc2fc8d442f1248231eab4b73e111665288b1e615 (patch)
tree9621879f1070cf1d99829fa020e87000f878a3fa
parenta24e24d67d88f75869f447690f7d994fe7d42e2d (diff)
update
-rw-r--r--models/base_model.py2
-rw-r--r--models/cycle_gan_model.py30
-rw-r--r--models/networks.py4
-rw-r--r--models/pix2pix_model.py10
-rw-r--r--models/test_model.py2
5 files changed, 23 insertions, 25 deletions
diff --git a/models/base_model.py b/models/base_model.py
index 646a014..9b55afe 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -44,7 +44,7 @@ class BaseModel():
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.cpu().state_dict(), save_path)
if len(gpu_ids) and torch.cuda.is_available():
- network.cuda(device_id=gpu_ids[0]) # network.cuda(device=gpu_ids[0]) for the latest version.
+ network.cuda(gpu_ids[0])
# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label):
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index e840e7b..fe06823 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -91,13 +91,13 @@ class CycleGANModel(BaseModel):
def test(self):
real_A = Variable(self.input_A, volatile=True)
- fake_B = self.netG_A.forward(real_A)
- self.rec_A = self.netG_B.forward(fake_B).data
+ fake_B = self.netG_A(real_A)
+ self.rec_A = self.netG_B(fake_B).data
self.fake_B = fake_B.data
real_B = Variable(self.input_B, volatile=True)
- fake_A = self.netG_B.forward(real_B)
- self.rec_B = self.netG_A.forward(fake_A).data
+ fake_A = self.netG_B(real_B)
+ self.rec_B = self.netG_A(fake_A).data
self.fake_A = fake_A.data
# get image paths
@@ -106,10 +106,10 @@ class CycleGANModel(BaseModel):
def backward_D_basic(self, netD, real, fake):
# Real
- pred_real = netD.forward(real)
+ pred_real = netD(real)
loss_D_real = self.criterionGAN(pred_real, True)
# Fake
- pred_fake = netD.forward(fake.detach())
+ pred_fake = netD(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss
loss_D = (loss_D_real + loss_D_fake) * 0.5
@@ -134,17 +134,16 @@ class CycleGANModel(BaseModel):
# Identity loss
if lambda_idt > 0:
# G_A should be identity if real_B is fed.
- idt_A = self.netG_A.forward(self.real_B)
+ idt_A = self.netG_A(self.real_B)
loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed.
- idt_B = self.netG_B.forward(self.real_A)
+ idt_B = self.netG_B(self.real_A)
loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt
self.idt_A = idt_A.data
self.idt_B = idt_B.data
self.loss_idt_A = loss_idt_A.data[0]
self.loss_idt_B = loss_idt_B.data[0]
-
else:
loss_idt_A = 0
loss_idt_B = 0
@@ -152,23 +151,22 @@ class CycleGANModel(BaseModel):
self.loss_idt_B = 0
# GAN loss D_A(G_A(A))
- fake_B = self.netG_A.forward(self.real_A)
- pred_fake = self.netD_A.forward(fake_B)
+ fake_B = self.netG_A(self.real_A)
+ pred_fake = self.netD_A(fake_B)
loss_G_A = self.criterionGAN(pred_fake, True)
# GAN loss D_B(G_B(B))
- fake_A = self.netG_B.forward(self.real_B)
- pred_fake = self.netD_B.forward(fake_A)
+ fake_A = self.netG_B(self.real_B)
+ pred_fake = self.netD_B(fake_A)
loss_G_B = self.criterionGAN(pred_fake, True)
# Forward cycle loss
- rec_A = self.netG_B.forward(fake_B)
+ rec_A = self.netG_B(fake_B)
loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
# Backward cycle loss
- rec_B = self.netG_A.forward(fake_A)
+ rec_B = self.netG_A(fake_A)
loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
-
# combined loss
loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
loss_G.backward()
diff --git a/models/networks.py b/models/networks.py
index d071ac4..ec6573b 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -118,7 +118,7 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
if len(gpu_ids) > 0:
- netG.cuda(device_id=gpu_ids[0]) # or netG.cuda(device=gpu_ids[0]) for latest version.
+ netG.cuda(gpu_ids[0])
init_weights(netG, init_type=init_type)
return netG
@@ -139,7 +139,7 @@ def define_D(input_nc, ndf, which_model_netD,
raise NotImplementedError('Discriminator model name [%s] is not recognized' %
which_model_netD)
if use_gpu:
- netD.cuda(device_id=gpu_ids[0]) # or netD.cuda(device=gpu_ids[0]) for latest version.
+ netD.cuda(gpu_ids[0])
init_weights(netD, init_type=init_type)
return netD
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 388a8d3..56adfc1 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -70,13 +70,13 @@ class Pix2PixModel(BaseModel):
def forward(self):
self.real_A = Variable(self.input_A)
- self.fake_B = self.netG.forward(self.real_A)
+ self.fake_B = self.netG(self.real_A)
self.real_B = Variable(self.input_B)
# no backprop gradients
def test(self):
self.real_A = Variable(self.input_A, volatile=True)
- self.fake_B = self.netG.forward(self.real_A)
+ self.fake_B = self.netG(self.real_A)
self.real_B = Variable(self.input_B, volatile=True)
# get image paths
@@ -87,12 +87,12 @@ class Pix2PixModel(BaseModel):
# Fake
# stop backprop to the generator by detaching fake_B
fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data)
- pred_fake = self.netD.forward(fake_AB.detach())
+ pred_fake = self.netD(fake_AB.detach())
self.loss_D_fake = self.criterionGAN(pred_fake, False)
# Real
real_AB = torch.cat((self.real_A, self.real_B), 1)
- pred_real = self.netD.forward(real_AB)
+ pred_real = self.netD(real_AB)
self.loss_D_real = self.criterionGAN(pred_real, True)
# Combined loss
@@ -103,7 +103,7 @@ class Pix2PixModel(BaseModel):
def backward_G(self):
# First, G(A) should fake the discriminator
fake_AB = torch.cat((self.real_A, self.fake_B), 1)
- pred_fake = self.netD.forward(fake_AB)
+ pred_fake = self.netD(fake_AB)
self.loss_G_GAN = self.criterionGAN(pred_fake, True)
# Second, G(A) = B
diff --git a/models/test_model.py b/models/test_model.py
index 4af1fe1..2ae2812 100644
--- a/models/test_model.py
+++ b/models/test_model.py
@@ -34,7 +34,7 @@ class TestModel(BaseModel):
def test(self):
self.real_A = Variable(self.input_A)
- self.fake_B = self.netG.forward(self.real_A)
+ self.fake_B = self.netG(self.real_A)
# get image paths
def get_image_paths(self):