diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/base_data_loader.py | 6 | ||||
| -rw-r--r-- | data/base_dataset.py | 3 | ||||
| -rw-r--r-- | data/data_loader.py | 1 | ||||
| -rw-r--r-- | data/single_dataset.py | 1 | ||||
| -rw-r--r-- | data/unaligned_dataset.py | 4 |
5 files changed, 5 insertions, 10 deletions
diff --git a/data/base_data_loader.py b/data/base_data_loader.py index 0e1deb5..ae5a168 100644 --- a/data/base_data_loader.py +++ b/data/base_data_loader.py @@ -1,14 +1,10 @@ - class BaseDataLoader(): def __init__(self): pass - + def initialize(self, opt): self.opt = opt pass def load_data(): return None - - - diff --git a/data/base_dataset.py b/data/base_dataset.py index a061a05..7cfac54 100644 --- a/data/base_dataset.py +++ b/data/base_dataset.py @@ -2,6 +2,7 @@ import torch.utils.data as data from PIL import Image import torchvision.transforms as transforms + class BaseDataset(data.Dataset): def __init__(self): super(BaseDataset, self).__init__() @@ -12,6 +13,7 @@ class BaseDataset(data.Dataset): def initialize(self, opt): pass + def get_transform(opt): transform_list = [] if opt.resize_or_crop == 'resize_and_crop': @@ -36,6 +38,7 @@ def get_transform(opt): (0.5, 0.5, 0.5))] return transforms.Compose(transform_list) + def __scale_width(img, target_width): ow, oh = img.size if (ow == target_width): diff --git a/data/data_loader.py b/data/data_loader.py index 2a4433a..22b6a8f 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -1,4 +1,3 @@ - def CreateDataLoader(opt): from data.custom_dataset_data_loader import CustomDatasetDataLoader data_loader = CustomDatasetDataLoader() diff --git a/data/single_dataset.py b/data/single_dataset.py index f8b4f1d..12083b1 100644 --- a/data/single_dataset.py +++ b/data/single_dataset.py @@ -1,5 +1,4 @@ import os.path -import torchvision.transforms as transforms from data.base_dataset import BaseDataset, get_transform from data.image_folder import make_dataset from PIL import Image diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py index ad0c11b..2f59b2a 100644 --- a/data/unaligned_dataset.py +++ b/data/unaligned_dataset.py @@ -1,11 +1,10 @@ import os.path -import torchvision.transforms as transforms from data.base_dataset import BaseDataset, get_transform from data.image_folder import make_dataset from PIL import Image -import PIL import random + class UnalignedDataset(BaseDataset): def initialize(self, opt): self.opt = opt @@ -24,7 +23,6 @@ class UnalignedDataset(BaseDataset): def __getitem__(self, index): A_path = self.A_paths[index % self.A_size] - index_A = index % self.A_size if self.opt.serial_batches: index_B = index % self.B_size else: |
