summaryrefslogtreecommitdiff
path: root/models/networks.py
diff options
context:
space:
mode:
authorjunyanz <junyanz@berkeley.edu>2017-10-19 21:14:48 -0700
committerjunyanz <junyanz@berkeley.edu>2017-10-19 21:14:48 -0700
commit9d1bc76e6a4f791a25db1179c7c2b4c62a8d55cd (patch)
treef8a3ad91d77f8ee2a1e377621ff76e15b3e1ec98 /models/networks.py
parent2a344ccd6f80ce0435d2c58ccfd94c76dd423b1a (diff)
fix learning rate
Diffstat (limited to 'models/networks.py')
-rw-r--r--models/networks.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/models/networks.py b/models/networks.py
index 19169c5..51e3f25 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -10,7 +10,6 @@ import numpy as np
###############################################################################
-
def weights_init_normal(m):
classname = m.__class__.__name__
# print(classname)
@@ -87,8 +86,8 @@ def get_norm_layer(norm_type='instance'):
def get_scheduler(optimizer, opt):
if opt.lr_policy == 'lambda':
- def lambda_rule(epoch): # epoch ranges from [1, opt.niter+opt.niter_decay]
- lr_l = 1.0 - max(0, epoch - opt.niter + 1) / float(opt.niter_decay + 1)
+ def lambda_rule(epoch):
+ lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif opt.lr_policy == 'step':