summaryrefslogtreecommitdiff
path: root/models/networks.py
diff options
context:
space:
mode:
authorTaesung Park <taesung_park@berkeley.edu>2017-12-10 23:04:41 -0800
committerTaesung Park <taesung_park@berkeley.edu>2017-12-10 23:04:41 -0800
commitf33f098be9b25c3b62523540c9c703af1db0b1c0 (patch)
tree9b51e547067b46ad8b55ddb34b207825550df867 /models/networks.py
parent3d2c534933b356dc313a620639a713cb940dc756 (diff)
parent2d96edbee5a488a7861833731a2cb71b23b55727 (diff)
merged conflicts
Diffstat (limited to 'models/networks.py')
-rw-r--r--models/networks.py23
1 files changed, 11 insertions, 12 deletions
diff --git a/models/networks.py b/models/networks.py
index 965bacb..568f8c9 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -10,16 +10,15 @@ import numpy as np
###############################################################################
-
def weights_init_normal(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
- init.uniform(m.weight.data, 0.0, 0.02)
+ init.normal(m.weight.data, 0.0, 0.02)
elif classname.find('Linear') != -1:
- init.uniform(m.weight.data, 0.0, 0.02)
+ init.normal(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
- init.uniform(m.weight.data, 1.0, 0.02)
+ init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
@@ -27,11 +26,11 @@ def weights_init_xavier(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
- init.xavier_normal(m.weight.data, gain=1)
+ init.xavier_normal(m.weight.data, gain=0.02)
elif classname.find('Linear') != -1:
- init.xavier_normal(m.weight.data, gain=1)
+ init.xavier_normal(m.weight.data, gain=0.02)
elif classname.find('BatchNorm2d') != -1:
- init.uniform(m.weight.data, 1.0, 0.02)
+ init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
@@ -43,7 +42,7 @@ def weights_init_kaiming(m):
elif classname.find('Linear') != -1:
init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
elif classname.find('BatchNorm2d') != -1:
- init.uniform(m.weight.data, 1.0, 0.02)
+ init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
@@ -55,7 +54,7 @@ def weights_init_orthogonal(m):
elif classname.find('Linear') != -1:
init.orthogonal(m.weight.data, gain=1)
elif classname.find('BatchNorm2d') != -1:
- init.uniform(m.weight.data, 1.0, 0.02)
+ init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
@@ -88,7 +87,7 @@ def get_norm_layer(norm_type='instance'):
def get_scheduler(optimizer, opt):
if opt.lr_policy == 'lambda':
def lambda_rule(epoch):
- lr_l = 1.0 - max(0, epoch - opt.niter) / float(opt.niter_decay+1)
+ lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif opt.lr_policy == 'step':
@@ -119,7 +118,7 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
if len(gpu_ids) > 0:
- netG.cuda(device_id=gpu_ids[0])
+ netG.cuda(gpu_ids[0])
init_weights(netG, init_type=init_type)
return netG
@@ -142,7 +141,7 @@ def define_D(input_nc, ndf, which_model_netD,
raise NotImplementedError('Discriminator model name [%s] is not recognized' %
which_model_netD)
if use_gpu:
- netD.cuda(device_id=gpu_ids[0])
+ netD.cuda(gpu_ids[0])
init_weights(netD, init_type=init_type)
return netD