| author | junyanz <junyanz@berkeley.edu> | 2017-10-06 10:46:43 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-10-06 10:46:43 -0700 |
| commit | 7800d516596f1a25986b458cddf8b8785bcc7df8 | |
| tree | 56d57350e7104393f939ec7cc2e07c96840aaa27 /models/cycle_gan_model.py | |
| parent | e986144cee13a921fd3ad68d564f820e8f7dd3b0 | |
support nc=1, add new learning rate policy and new initialization
Diffstat (limited to 'models/cycle_gan_model.py')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | models/cycle_gan_model.py | 20 |

1 file changed, 7 insertions, 13 deletions
```diff
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index b3c52c7..c6b336c 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -61,6 +61,13 @@ class CycleGANModel(BaseModel):
                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
             self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
             self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.optimizers = []
+            self.schedulers = []
+            self.optimizers.append(self.optimizer_G)
+            self.optimizers.append(self.optimizer_D_A)
+            self.optimizers.append(self.optimizer_D_B)
+            for optimizer in self.optimizers:
+                self.schedulers.append(networks.get_scheduler(optimizer, opt))
 
         print('---------- Networks initialized -------------')
         networks.print_network(self.netG_A)
@@ -204,16 +211,3 @@ class CycleGANModel(BaseModel):
         self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
         self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
         self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
-
-    def update_learning_rate(self):
-        lrd = self.opt.lr / self.opt.niter_decay
-        lr = self.old_lr - lrd
-        for param_group in self.optimizer_D_A.param_groups:
-            param_group['lr'] = lr
-        for param_group in self.optimizer_D_B.param_groups:
-            param_group['lr'] = lr
-        for param_group in self.optimizer_G.param_groups:
-            param_group['lr'] = lr
-
-        print('update learning rate: %f -> %f' % (self.old_lr, lr))
-        self.old_lr = lr
```
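The deleted `update_learning_rate` hard-coded a single policy: subtract `opt.lr / opt.niter_decay` from the current rate each time it was called. The new code instead collects the three optimizers and asks `networks.get_scheduler` for a scheduler per optimizer; that factory's body is not part of this diff. Below is a minimal sketch of what such a factory could look like, built only from stock `torch.optim.lr_scheduler` classes. Everything beyond the `get_scheduler(optimizer, opt)` call and the `opt.lr` / `opt.niter_decay` options visible in the deleted code is an assumption, including the option names `lr_policy`, `niter`, and `lr_decay_iters`.

```python
# Hypothetical sketch of networks.get_scheduler; the diff calls it but never
# shows its body. opt.lr_policy, opt.niter, and opt.lr_decay_iters are
# assumed option names; opt.niter_decay appears in the deleted method.
from torch.optim import lr_scheduler


def get_scheduler(optimizer, opt):
    if opt.lr_policy == 'lambda':
        # Hold the initial lr for opt.niter epochs, then fade linearly toward
        # zero over the next opt.niter_decay epochs. This matches the end
        # state of the deleted update_learning_rate, which subtracted
        # opt.lr / opt.niter_decay from the rate once per epoch.
        def lambda_rule(epoch):
            return 1.0 - max(0, epoch - opt.niter) / float(opt.niter_decay + 1)
        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        # Alternative policy: multiply lr by 0.1 every opt.lr_decay_iters epochs.
        return lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    else:
        raise NotImplementedError('lr policy [%s] is not implemented' % opt.lr_policy)
```

With the schedulers stored on the model, the per-epoch update in the training loop collapses to something like `for scheduler in self.schedulers: scheduler.step()`, rather than hand-editing `param_group['lr']` in every optimizer as the removed method did.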
