summaryrefslogtreecommitdiff
path: root/models/cycle_gan_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/cycle_gan_model.py')
-rw-r--r--models/cycle_gan_model.py20
1 files changed, 7 insertions, 13 deletions
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index b3c52c7..c6b336c 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -61,6 +61,13 @@ class CycleGANModel(BaseModel):
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizers = []
+ self.schedulers = []
+ self.optimizers.append(self.optimizer_G)
+ self.optimizers.append(self.optimizer_D_A)
+ self.optimizers.append(self.optimizer_D_B)
+ for optimizer in self.optimizers:
+ self.schedulers.append(networks.get_scheduler(optimizer, opt))
print('---------- Networks initialized -------------')
networks.print_network(self.netG_A)
@@ -204,16 +211,3 @@ class CycleGANModel(BaseModel):
self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
-
- def update_learning_rate(self):
- lrd = self.opt.lr / self.opt.niter_decay
- lr = self.old_lr - lrd
- for param_group in self.optimizer_D_A.param_groups:
- param_group['lr'] = lr
- for param_group in self.optimizer_D_B.param_groups:
- param_group['lr'] = lr
- for param_group in self.optimizer_G.param_groups:
- param_group['lr'] = lr
-
- print('update learning rate: %f -> %f' % (self.old_lr, lr))
- self.old_lr = lr