summaryrefslogtreecommitdiff
path: root/models/pix2pix_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/pix2pix_model.py')
-rw-r--r--models/pix2pix_model.py20
1 files changed, 8 insertions, 12 deletions
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index a524f2c..18ba53f 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -24,12 +24,12 @@ class Pix2PixModel(BaseModel):
# load/define networks
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
- opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)
+ opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
if self.isTrain:
use_sigmoid = opt.no_lsgan
self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
opt.which_model_netD,
- opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
+ opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
if not self.isTrain or opt.continue_train:
self.load_network(self.netG, 'G', opt.which_epoch)
if self.isTrain:
@@ -43,10 +43,16 @@ class Pix2PixModel(BaseModel):
self.criterionL1 = torch.nn.L1Loss()
# initialize optimizers
+ self.schedulers = []
+ self.optimizers = []
self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizers.append(self.optimizer_G)
+ self.optimizers.append(self.optimizer_D)
+ for optimizer in self.optimizers:
+ self.schedulers.append(networks.get_scheduler(optimizer, opt))
print('---------- Networks initialized -------------')
networks.print_network(self.netG)
@@ -134,13 +140,3 @@ class Pix2PixModel(BaseModel):
def save(self, label):
self.save_network(self.netG, 'G', label, self.gpu_ids)
self.save_network(self.netD, 'D', label, self.gpu_ids)
-
- def update_learning_rate(self):
- lrd = self.opt.lr / self.opt.niter_decay
- lr = self.old_lr - lrd
- for param_group in self.optimizer_D.param_groups:
- param_group['lr'] = lr
- for param_group in self.optimizer_G.param_groups:
- param_group['lr'] = lr
- print('update learning rate: %f -> %f' % (self.old_lr, lr))
- self.old_lr = lr