summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTaesung Park <taesung89@gmail.com>2017-12-01 22:13:51 -0800
committerTaesung Park <taesung89@gmail.com>2017-12-01 22:13:51 -0800
commit7bda734dd7f3466d5d55afe80b97542b1b12bdb5 (patch)
tree82cfdacd331189c4367745a510f09aef833c77a6
parent1615932f9180a7a9df92f33fbb8749aec432d3d9 (diff)
changed the gain of xavier initialization from 1 to 0.02. implemented serial_batches option in unaligned dataset
-rw-r--r--data/unaligned_dataset.py5
-rw-r--r--models/networks.py4
2 files changed, 6 insertions, 3 deletions
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
index c5e5460..ad0c11b 100644
--- a/data/unaligned_dataset.py
+++ b/data/unaligned_dataset.py
@@ -25,7 +25,10 @@ class UnalignedDataset(BaseDataset):
def __getitem__(self, index):
A_path = self.A_paths[index % self.A_size]
index_A = index % self.A_size
- index_B = random.randint(0, self.B_size - 1)
+ if self.opt.serial_batches:
+ index_B = index % self.B_size
+ else:
+ index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B]
# print('(A, B) = (%d, %d)' % (index_A, index_B))
A_img = Image.open(A_path).convert('RGB')
diff --git a/models/networks.py b/models/networks.py
index e6e0a87..3c54138 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -26,9 +26,9 @@ def weights_init_xavier(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
- init.xavier_normal(m.weight.data, gain=1)
+ init.xavier_normal(m.weight.data, gain=0.02)
elif classname.find('Linear') != -1:
- init.xavier_normal(m.weight.data, gain=1)
+ init.xavier_normal(m.weight.data, gain=0.02)
elif classname.find('BatchNorm2d') != -1:
init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)