diff options
Diffstat (limited to 'models')
| -rw-r--r-- | models/cycle_gan_model.py | 16 | ||||
| -rw-r--r-- | models/models.py | 1 | ||||
| -rw-r--r-- | models/networks.py | 5 | ||||
| -rw-r--r-- | models/pix2pix_model.py | 3 |
4 files changed, 6 insertions, 19 deletions
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index b7b840d..85432bb 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -1,6 +1,4 @@ -import numpy as np import torch -import os from collections import OrderedDict from torch.autograd import Variable import itertools @@ -8,7 +6,6 @@ import util.util as util from util.image_pool import ImagePool from .base_model import BaseModel from . import networks -import sys class CycleGANModel(BaseModel): @@ -17,10 +14,6 @@ class CycleGANModel(BaseModel): def initialize(self, opt): BaseModel.initialize(self, opt) - - nb = opt.batchSize - size = opt.fineSize - # load/define networks # The naming conversion is different from those used in the paper # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) @@ -47,7 +40,6 @@ class CycleGANModel(BaseModel): self.load_network(self.netD_B, 'D_B', which_epoch) if self.isTrain: - self.old_lr = opt.lr self.fake_A_pool = ImagePool(opt.pool_size) self.fake_B_pool = ImagePool(opt.pool_size) # define loss functions @@ -129,7 +121,7 @@ class CycleGANModel(BaseModel): self.loss_D_B = loss_D_B.data[0] def backward_G(self): - lambda_idt = self.opt.identity + lambda_idt = self.opt.lambda_identity lambda_A = self.opt.lambda_A lambda_B = self.opt.lambda_B # Identity loss @@ -200,8 +192,8 @@ class CycleGANModel(BaseModel): def get_current_errors(self): ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A), - ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)]) - if self.opt.identity > 0.0: + ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)]) + if self.opt.lambda_identity > 0.0: ret_errors['idt_A'] = self.loss_idt_A ret_errors['idt_B'] = self.loss_idt_B return ret_errors @@ -215,7 +207,7 @@ class CycleGANModel(BaseModel): rec_B = util.tensor2im(self.rec_B) ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)]) - if self.opt.isTrain and self.opt.identity > 0.0: + if self.opt.isTrain and self.opt.lambda_identity > 0.0: ret_visuals['idt_A'] = util.tensor2im(self.idt_A) ret_visuals['idt_B'] = util.tensor2im(self.idt_B) return ret_visuals diff --git a/models/models.py b/models/models.py index d5bb9d8..39cc020 100644 --- a/models/models.py +++ b/models/models.py @@ -1,4 +1,3 @@ - def create_model(opt): model = None print(opt.model) diff --git a/models/networks.py b/models/networks.py index da2f59c..b118c6a 100644 --- a/models/networks.py +++ b/models/networks.py @@ -4,7 +4,6 @@ from torch.nn import init import functools from torch.autograd import Variable from torch.optim import lr_scheduler -import numpy as np ############################################################################### # Functions ############################################################################### @@ -434,6 +433,7 @@ class NLayerDiscriminator(nn.Module): else: return self.model(input) + class PixelDiscriminator(nn.Module): def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]): super(PixelDiscriminator, self).__init__() @@ -442,7 +442,7 @@ class PixelDiscriminator(nn.Module): use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d - + self.net = [ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), nn.LeakyReLU(0.2, True), @@ -461,4 +461,3 @@ class PixelDiscriminator(nn.Module): return nn.parallel.data_parallel(self.net, input, self.gpu_ids) else: return self.net(input) - diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 74a941e..78f8d69 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -1,6 +1,4 @@ -import numpy as np import torch -import os from collections import OrderedDict from torch.autograd import Variable import util.util as util @@ -32,7 +30,6 @@ class Pix2PixModel(BaseModel): if self.isTrain: self.fake_AB_pool = ImagePool(opt.pool_size) - self.old_lr = opt.lr # define loss functions self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) self.criterionL1 = torch.nn.L1Loss() |
