From 6b8e96c4bbd73a1e1d4e126d795a26fd0dae983c Mon Sep 17 00:00:00 2001 From: junyanz Date: Sat, 4 Nov 2017 02:27:18 -0700 Subject: add update_html_freq flag --- models/pix2pix_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'models/pix2pix_model.py') diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 18ba53f..8cd494f 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -87,12 +87,12 @@ class Pix2PixModel(BaseModel): # Fake # stop backprop to the generator by detaching fake_B fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) - self.pred_fake = self.netD.forward(fake_AB.detach()) - self.loss_D_fake = self.criterionGAN(self.pred_fake, False) + pred_fake = self.netD.forward(fake_AB.detach()) + self.loss_D_fake = self.criterionGAN(pred_fake, False) # Real real_AB = torch.cat((self.real_A, self.real_B), 1) - self.pred_real = self.netD.forward(real_AB) + pred_real = self.netD.forward(real_AB) self.loss_D_real = self.criterionGAN(self.pred_real, True) # Combined loss -- cgit v1.2.3-70-g09d2 From 7a9021d4f131ee059d49ff9b2d135e6543f75763 Mon Sep 17 00:00:00 2001 From: junyanz Date: Sat, 4 Nov 2017 02:47:39 -0700 Subject: fix small issues --- models/pix2pix_model.py | 4 ++-- util/image_pool.py | 2 +- util/visualizer.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) (limited to 'models/pix2pix_model.py') diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 8cd494f..388a8d3 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -86,14 +86,14 @@ class Pix2PixModel(BaseModel): def backward_D(self): # Fake # stop backprop to the generator by detaching fake_B - fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) + fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data) pred_fake = self.netD.forward(fake_AB.detach()) self.loss_D_fake = self.criterionGAN(pred_fake, False) # Real real_AB = torch.cat((self.real_A, self.real_B), 1) pred_real = self.netD.forward(real_AB) - self.loss_D_real = self.criterionGAN(self.pred_real, True) + self.loss_D_real = self.criterionGAN(pred_real, True) # Combined loss self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 diff --git a/util/image_pool.py b/util/image_pool.py index 5a242e6..ada1627 100644 --- a/util/image_pool.py +++ b/util/image_pool.py @@ -13,7 +13,7 @@ class ImagePool(): def query(self, images): if self.pool_size == 0: - return images + return Variable(images) return_images = [] for image in images: image = torch.unsqueeze(image, 0) diff --git a/util/visualizer.py b/util/visualizer.py index 22fe9da..e6e7cba 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -13,11 +13,11 @@ class Visualizer(): self.use_html = opt.isTrain and not opt.no_html self.win_size = opt.display_winsize self.name = opt.name + self.opt = opt self.saved = False if self.display_id > 0: import visdom self.vis = visdom.Visdom(port=opt.display_port) - self.display_single_pane_ncols = opt.display_single_pane_ncols if self.use_html: self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') @@ -35,13 +35,13 @@ class Visualizer(): # |visuals|: dictionary of images to display or save def display_current_results(self, visuals, epoch, save_result): if self.display_id > 0: # show images in the browser - if self.display_single_pane_ncols > 0: + ncols = self.opt.display_single_pane_ncols + if ncols > 0: h, w = next(iter(visuals.values())).shape[:2] table_css = """""" % (w, h) - ncols = self.display_single_pane_ncols title = self.name label_html = '' label_html_row = '' @@ -76,7 +76,7 @@ class Visualizer(): idx += 1 if self.use_html and (save_result or not self.saved): # save images to a html file - self.saved = True + self.saved = True for label, image_numpy in visuals.items(): img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) util.save_image(image_numpy, img_path) -- cgit v1.2.3-70-g09d2 From c2fc8d442f1248231eab4b73e111665288b1e615 Mon Sep 17 00:00:00 2001 From: SsnL Date: Thu, 9 Nov 2017 16:08:30 -0500 Subject: update --- models/base_model.py | 2 +- models/cycle_gan_model.py | 30 ++++++++++++++---------------- models/networks.py | 4 ++-- models/pix2pix_model.py | 10 +++++----- models/test_model.py | 2 +- 5 files changed, 23 insertions(+), 25 deletions(-) (limited to 'models/pix2pix_model.py') diff --git a/models/base_model.py b/models/base_model.py index 646a014..9b55afe 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -44,7 +44,7 @@ class BaseModel(): save_path = os.path.join(self.save_dir, save_filename) torch.save(network.cpu().state_dict(), save_path) if len(gpu_ids) and torch.cuda.is_available(): - network.cuda(device_id=gpu_ids[0]) # network.cuda(device=gpu_ids[0]) for the latest version. + network.cuda(gpu_ids[0]) # helper loading function that can be used by subclasses def load_network(self, network, network_label, epoch_label): diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index e840e7b..fe06823 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -91,13 +91,13 @@ class CycleGANModel(BaseModel): def test(self): real_A = Variable(self.input_A, volatile=True) - fake_B = self.netG_A.forward(real_A) - self.rec_A = self.netG_B.forward(fake_B).data + fake_B = self.netG_A(real_A) + self.rec_A = self.netG_B(fake_B).data self.fake_B = fake_B.data real_B = Variable(self.input_B, volatile=True) - fake_A = self.netG_B.forward(real_B) - self.rec_B = self.netG_A.forward(fake_A).data + fake_A = self.netG_B(real_B) + self.rec_B = self.netG_A(fake_A).data self.fake_A = fake_A.data # get image paths @@ -106,10 +106,10 @@ class CycleGANModel(BaseModel): def backward_D_basic(self, netD, real, fake): # Real - pred_real = netD.forward(real) + pred_real = netD(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake - pred_fake = netD.forward(fake.detach()) + pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss loss_D = (loss_D_real + loss_D_fake) * 0.5 @@ -134,17 +134,16 @@ class CycleGANModel(BaseModel): # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed. - idt_A = self.netG_A.forward(self.real_B) + idt_A = self.netG_A(self.real_B) loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed. - idt_B = self.netG_B.forward(self.real_A) + idt_B = self.netG_B(self.real_A) loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt self.idt_A = idt_A.data self.idt_B = idt_B.data self.loss_idt_A = loss_idt_A.data[0] self.loss_idt_B = loss_idt_B.data[0] - else: loss_idt_A = 0 loss_idt_B = 0 @@ -152,23 +151,22 @@ class CycleGANModel(BaseModel): self.loss_idt_B = 0 # GAN loss D_A(G_A(A)) - fake_B = self.netG_A.forward(self.real_A) - pred_fake = self.netD_A.forward(fake_B) + fake_B = self.netG_A(self.real_A) + pred_fake = self.netD_A(fake_B) loss_G_A = self.criterionGAN(pred_fake, True) # GAN loss D_B(G_B(B)) - fake_A = self.netG_B.forward(self.real_B) - pred_fake = self.netD_B.forward(fake_A) + fake_A = self.netG_B(self.real_B) + pred_fake = self.netD_B(fake_A) loss_G_B = self.criterionGAN(pred_fake, True) # Forward cycle loss - rec_A = self.netG_B.forward(fake_B) + rec_A = self.netG_B(fake_B) loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A # Backward cycle loss - rec_B = self.netG_A.forward(fake_A) + rec_B = self.netG_A(fake_A) loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B - # combined loss loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B loss_G.backward() diff --git a/models/networks.py b/models/networks.py index d071ac4..ec6573b 100644 --- a/models/networks.py +++ b/models/networks.py @@ -118,7 +118,7 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo else: raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) if len(gpu_ids) > 0: - netG.cuda(device_id=gpu_ids[0]) # or netG.cuda(device=gpu_ids[0]) for latest version. + netG.cuda(gpu_ids[0]) init_weights(netG, init_type=init_type) return netG @@ -139,7 +139,7 @@ def define_D(input_nc, ndf, which_model_netD, raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD) if use_gpu: - netD.cuda(device_id=gpu_ids[0]) # or netD.cuda(device=gpu_ids[0]) for latest version. + netD.cuda(gpu_ids[0]) init_weights(netD, init_type=init_type) return netD diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 388a8d3..56adfc1 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -70,13 +70,13 @@ class Pix2PixModel(BaseModel): def forward(self): self.real_A = Variable(self.input_A) - self.fake_B = self.netG.forward(self.real_A) + self.fake_B = self.netG(self.real_A) self.real_B = Variable(self.input_B) # no backprop gradients def test(self): self.real_A = Variable(self.input_A, volatile=True) - self.fake_B = self.netG.forward(self.real_A) + self.fake_B = self.netG(self.real_A) self.real_B = Variable(self.input_B, volatile=True) # get image paths @@ -87,12 +87,12 @@ class Pix2PixModel(BaseModel): # Fake # stop backprop to the generator by detaching fake_B fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data) - pred_fake = self.netD.forward(fake_AB.detach()) + pred_fake = self.netD(fake_AB.detach()) self.loss_D_fake = self.criterionGAN(pred_fake, False) # Real real_AB = torch.cat((self.real_A, self.real_B), 1) - pred_real = self.netD.forward(real_AB) + pred_real = self.netD(real_AB) self.loss_D_real = self.criterionGAN(pred_real, True) # Combined loss @@ -103,7 +103,7 @@ class Pix2PixModel(BaseModel): def backward_G(self): # First, G(A) should fake the discriminator fake_AB = torch.cat((self.real_A, self.fake_B), 1) - pred_fake = self.netD.forward(fake_AB) + pred_fake = self.netD(fake_AB) self.loss_G_GAN = self.criterionGAN(pred_fake, True) # Second, G(A) = B diff --git a/models/test_model.py b/models/test_model.py index 4af1fe1..2ae2812 100644 --- a/models/test_model.py +++ b/models/test_model.py @@ -34,7 +34,7 @@ class TestModel(BaseModel): def test(self): self.real_A = Variable(self.input_A) - self.fake_B = self.netG.forward(self.real_A) + self.fake_B = self.netG(self.real_A) # get image paths def get_image_paths(self): -- cgit v1.2.3-70-g09d2