summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjunyanz <junyanz@berkeley.edu>2017-08-26 19:04:12 -0700
committerjunyanz <junyanz@berkeley.edu>2017-08-26 19:04:12 -0700
commit174682626f21b4222a1cb294348e59ad3b260eb7 (patch)
tree85d7c59f3e2aa28b0d96836ae28b7621201b4cc7
parent8c7cd2e23cdfa45e84301e66ff81bbc8a369aca7 (diff)
fix the unaligned dataset
-rw-r--r--data/unaligned_dataset.py8
-rw-r--r--models/networks.py6
2 files changed, 8 insertions, 6 deletions
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
index 3864bf3..b162869 100644
--- a/data/unaligned_dataset.py
+++ b/data/unaligned_dataset.py
@@ -5,7 +5,7 @@ from data.image_folder import make_dataset
from PIL import Image
import PIL
from pdb import set_trace as st
-
+import random
class UnalignedDataset(BaseDataset):
def initialize(self, opt):
@@ -25,8 +25,10 @@ class UnalignedDataset(BaseDataset):
def __getitem__(self, index):
A_path = self.A_paths[index % self.A_size]
- B_path = self.B_paths[index % self.B_size]
-
+ index_A = index % self.A_size
+ index_B = random.randint(0, self.B_size)
+ B_path = self.B_paths[index_B]
+ # print('(A, B) = (%d, %d)' % (index_A, index_B))
A_img = Image.open(A_path).convert('RGB')
B_img = Image.open(B_path).convert('RGB')
diff --git a/models/networks.py b/models/networks.py
index db36ac4..585b940 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -162,7 +162,7 @@ class ResnetGenerator(nn.Module):
mult = 2**n_downsampling
for i in range(n_blocks):
- model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout)]
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
@@ -189,9 +189,9 @@ class ResnetGenerator(nn.Module):
class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
super(ResnetBlock, self).__init__()
- self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
- def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
conv_block = []
p = 0
if padding_type == 'reflect':