diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-08-26 19:04:12 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-08-26 19:04:12 -0700 |
| commit | 174682626f21b4222a1cb294348e59ad3b260eb7 (patch) | |
| tree | 85d7c59f3e2aa28b0d96836ae28b7621201b4cc7 | |
| parent | 8c7cd2e23cdfa45e84301e66ff81bbc8a369aca7 (diff) | |
fix the unaligned dataset
| -rw-r--r-- | data/unaligned_dataset.py | 8 | ||||
| -rw-r--r-- | models/networks.py | 6 |
2 files changed, 8 insertions, 6 deletions
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py index 3864bf3..b162869 100644 --- a/data/unaligned_dataset.py +++ b/data/unaligned_dataset.py @@ -5,7 +5,7 @@ from data.image_folder import make_dataset from PIL import Image import PIL from pdb import set_trace as st - +import random class UnalignedDataset(BaseDataset): def initialize(self, opt): @@ -25,8 +25,10 @@ class UnalignedDataset(BaseDataset): def __getitem__(self, index): A_path = self.A_paths[index % self.A_size] - B_path = self.B_paths[index % self.B_size] - + index_A = index % self.A_size + index_B = random.randint(0, self.B_size) + B_path = self.B_paths[index_B] + # print('(A, B) = (%d, %d)' % (index_A, index_B)) A_img = Image.open(A_path).convert('RGB') B_img = Image.open(B_path).convert('RGB') diff --git a/models/networks.py b/models/networks.py index db36ac4..585b940 100644 --- a/models/networks.py +++ b/models/networks.py @@ -162,7 +162,7 @@ class ResnetGenerator(nn.Module): mult = 2**n_downsampling for i in range(n_blocks): - model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout)] + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] for i in range(n_downsampling): mult = 2**(n_downsampling - i) @@ -189,9 +189,9 @@ class ResnetGenerator(nn.Module): class ResnetBlock(nn.Module): def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): super(ResnetBlock, self).__init__() - self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout) + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) - def build_conv_block(self, dim, padding_type, norm_layer, use_dropout): + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): conv_block = [] p = 0 if padding_type == 'reflect': |
