Diffstat (limited to 'models/pix2pix_model.py')
 -rw-r--r--  models/pix2pix_model.py | 20 ++++++++------------
 1 file changed, 8 insertions(+), 12 deletions(-)
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index a524f2c..18ba53f 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -24,12 +24,12 @@ class Pix2PixModel(BaseModel):
         # load/define networks
         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
-                                      opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)
+                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
 
         if self.isTrain:
             use_sigmoid = opt.no_lsgan
             self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                           opt.which_model_netD,
-                                          opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
+                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
         if not self.isTrain or opt.continue_train:
             self.load_network(self.netG, 'G', opt.which_epoch)
             if self.isTrain:
@@ -43,10 +43,16 @@ class Pix2PixModel(BaseModel):
             self.criterionL1 = torch.nn.L1Loss()
 
             # initialize optimizers
+            self.schedulers = []
+            self.optimizers = []
             self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
             self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.optimizers.append(self.optimizer_G)
+            self.optimizers.append(self.optimizer_D)
+            for optimizer in self.optimizers:
+                self.schedulers.append(networks.get_scheduler(optimizer, opt))
 
         print('---------- Networks initialized -------------')
         networks.print_network(self.netG)
@@ -134,13 +140,3 @@ class Pix2PixModel(BaseModel):
     def save(self, label):
         self.save_network(self.netG, 'G', label, self.gpu_ids)
         self.save_network(self.netD, 'D', label, self.gpu_ids)
-
-    def update_learning_rate(self):
-        lrd = self.opt.lr / self.opt.niter_decay
-        lr = self.old_lr - lrd
-        for param_group in self.optimizer_D.param_groups:
-            param_group['lr'] = lr
-        for param_group in self.optimizer_G.param_groups:
-            param_group['lr'] = lr
-        print('update learning rate: %f -> %f' % (self.old_lr, lr))
-        self.old_lr = lr
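
The first hunk threads a new opt.init_type argument through networks.define_G and networks.define_D. networks.py is outside this diff, so the following is only a sketch of the kind of weight-initialization dispatch such a flag typically selects; the helper name init_weights and the supported values ('normal', 'xavier', 'kaiming') are assumptions, not confirmed by the patch.

    # Hypothetical sketch: networks.py is not shown in this diff. A define_G/define_D
    # that accepts init_type would plausibly apply a per-layer initializer like this.
    from torch.nn import init

    def init_weights(net, init_type='normal'):
        """Initialize conv/linear weights according to init_type (assumed values)."""
        def init_func(m):
            classname = m.__class__.__name__
            if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
                if init_type == 'normal':
                    init.normal_(m.weight.data, 0.0, 0.02)
                elif init_type == 'xavier':
                    init.xavier_normal_(m.weight.data)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                else:
                    raise NotImplementedError('init type [%s] not implemented' % init_type)
            elif 'BatchNorm2d' in classname:
                # BatchNorm affine parameters: weight near 1, bias at 0
                init.normal_(m.weight.data, 1.0, 0.02)
                init.constant_(m.bias.data, 0.0)
        net.apply(init_func)  # apply() recursively visits every submodule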
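
The deleted update_learning_rate decayed the learning rate by a fixed opt.lr / opt.niter_decay per epoch by hand, mutating each optimizer's param_groups; the new code instead builds one scheduler per optimizer via networks.get_scheduler. That function is also outside this diff; a minimal sketch that reproduces the removed linear decay with torch.optim.lr_scheduler.LambdaLR might look like the following, where opt.lr_policy and opt.niter are assumed option names (opt.niter_decay comes from the removed code).

    # Minimal sketch of networks.get_scheduler, assuming a 'lambda' (linear-decay)
    # policy; opt.lr_policy and opt.niter are hypothetical option names.
    from torch.optim import lr_scheduler

    def get_scheduler(optimizer, opt):
        if opt.lr_policy == 'lambda':
            def lambda_rule(epoch):
                # keep the initial lr for opt.niter epochs, then decay linearly
                # to zero over the next opt.niter_decay epochs
                return 1.0 - max(0, epoch + 1 - opt.niter) / float(opt.niter_decay + 1)
            return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
        raise NotImplementedError('lr policy [%s] not implemented' % opt.lr_policy)

With the schedulers stored on the model, a shared update_learning_rate can simply call scheduler.step() on each entry of self.schedulers, which is presumably why the per-model implementation is deleted here.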
