summaryrefslogtreecommitdiff
path: root/models/networks.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/networks.py')
-rw-r--r--models/networks.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/models/networks.py b/models/networks.py
index ec6573b..e6e0a87 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -14,11 +14,11 @@ def weights_init_normal(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
- init.uniform(m.weight.data, 0.0, 0.02)
+ init.normal(m.weight.data, 0.0, 0.02)
elif classname.find('Linear') != -1:
- init.uniform(m.weight.data, 0.0, 0.02)
+ init.normal(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
- init.uniform(m.weight.data, 1.0, 0.02)
+ init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
@@ -30,7 +30,7 @@ def weights_init_xavier(m):
elif classname.find('Linear') != -1:
init.xavier_normal(m.weight.data, gain=1)
elif classname.find('BatchNorm2d') != -1:
- init.uniform(m.weight.data, 1.0, 0.02)
+ init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
@@ -42,7 +42,7 @@ def weights_init_kaiming(m):
elif classname.find('Linear') != -1:
init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
elif classname.find('BatchNorm2d') != -1:
- init.uniform(m.weight.data, 1.0, 0.02)
+ init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
@@ -54,7 +54,7 @@ def weights_init_orthogonal(m):
elif classname.find('Linear') != -1:
init.orthogonal(m.weight.data, gain=1)
elif classname.find('BatchNorm2d') != -1:
- init.uniform(m.weight.data, 1.0, 0.02)
+ init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)