From 9d1bc76e6a4f791a25db1179c7c2b4c62a8d55cd Mon Sep 17 00:00:00 2001 From: junyanz Date: Thu, 19 Oct 2017 21:14:48 -0700 Subject: fix learning rate --- models/networks.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'models/networks.py') diff --git a/models/networks.py b/models/networks.py index 19169c5..51e3f25 100644 --- a/models/networks.py +++ b/models/networks.py @@ -10,7 +10,6 @@ import numpy as np ############################################################################### - def weights_init_normal(m): classname = m.__class__.__name__ # print(classname) @@ -87,8 +86,8 @@ def get_norm_layer(norm_type='instance'): def get_scheduler(optimizer, opt): if opt.lr_policy == 'lambda': - def lambda_rule(epoch): # epoch ranges from [1, opt.niter+opt.niter_decay] - lr_l = 1.0 - max(0, epoch - opt.niter + 1) / float(opt.niter_decay + 1) + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) return lr_l scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) elif opt.lr_policy == 'step': -- cgit v1.2.3-70-g09d2