From c9c46e1ef175862cefa6475dca900bb383cea53e Mon Sep 17 00:00:00 2001 From: Jules Laplace Date: Mon, 24 Sep 2018 15:15:16 +0200 Subject: make aligned dataset --- data/aligned_dataset.py | 7 ++++--- data/image_folder.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py index 41468d2..bb6947e 100755 --- a/data/aligned_dataset.py +++ b/data/aligned_dataset.py @@ -2,7 +2,7 @@ ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). import os.path from data.base_dataset import BaseDataset, get_params, get_transform, normalize -from data.image_folder import make_dataset +from data.image_folder import make_dataset, make_aligned_dataset from PIL import Image class AlignedDataset(BaseDataset): @@ -13,13 +13,14 @@ class AlignedDataset(BaseDataset): ### input A (label maps) dir_A = '_A' if self.opt.label_nc == 0 else '_label' self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A) - self.A_paths = sorted(make_dataset(self.dir_A)) ### input B (real images) if opt.isTrain: dir_B = '_B' if self.opt.label_nc == 0 else '_img' self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B) - self.B_paths = sorted(make_dataset(self.dir_B)) + self.A_paths, self.B_paths = make_aligned_dataset(self.dir_A, self.dir_B) + else: + self.A_paths = sorted(make_dataset(self.dir_A)) ### instance maps if not opt.no_instance: diff --git a/data/image_folder.py b/data/image_folder.py index df0141f..ba675a3 100755 --- a/data/image_folder.py +++ b/data/image_folder.py @@ -30,6 +30,20 @@ def make_dataset(dir): return images +def make_aligned_dataset(dir, dir_b): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + assert os.path.isdir(dir_b), '%s is not a valid directory' % dir_b + + files_A = [f for f in os.listdir(dir) if os.path.isfile(os.path.join(mypath, f) and is_image_file(f))] + files_B = [f for f in os.listdir(dir_b) if os.path.isfile(os.path.join(mypath, f) and is_image_file(f))] + images = sorted(list(set(files_A).intersection(files_B))) + # path = os.path.join(root, fname) + # images.append(path) + images_A = [os.path.join(dir, f) for f in images] + images_B = [os.path.join(dir_b, f) for f in images] + return images_A, images_B + def default_loader(path): return Image.open(path).convert('RGB') -- cgit v1.2.3-70-g09d2