summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-09-24 15:15:16 +0200
committerJules Laplace <julescarbon@gmail.com>2018-09-24 15:15:16 +0200
commitc9c46e1ef175862cefa6475dca900bb383cea53e (patch)
tree9403f252fc7b7b6c25d4d69867c277d1aa8b6213
parent59af4ea5a1084db67c5470e2a38304bd9e5c0760 (diff)
make aligned dataset
-rwxr-xr-xdata/aligned_dataset.py7
-rwxr-xr-xdata/image_folder.py14
2 files changed, 18 insertions, 3 deletions
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index 41468d2..bb6947e 100755
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -2,7 +2,7 @@
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
-from data.image_folder import make_dataset
+from data.image_folder import make_dataset, make_aligned_dataset
from PIL import Image
class AlignedDataset(BaseDataset):
@@ -13,13 +13,14 @@ class AlignedDataset(BaseDataset):
### input A (label maps)
dir_A = '_A' if self.opt.label_nc == 0 else '_label'
self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
- self.A_paths = sorted(make_dataset(self.dir_A))
### input B (real images)
if opt.isTrain:
dir_B = '_B' if self.opt.label_nc == 0 else '_img'
self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
- self.B_paths = sorted(make_dataset(self.dir_B))
+ self.A_paths, self.B_paths = make_aligned_dataset(self.dir_A, self.dir_B)
+ else:
+ self.A_paths = sorted(make_dataset(self.dir_A))
### instance maps
if not opt.no_instance:
diff --git a/data/image_folder.py b/data/image_folder.py
index df0141f..ba675a3 100755
--- a/data/image_folder.py
+++ b/data/image_folder.py
@@ -30,6 +30,20 @@ def make_dataset(dir):
return images
+def make_aligned_dataset(dir, dir_b):
+ images = []
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
+ assert os.path.isdir(dir_b), '%s is not a valid directory' % dir_b
+
+ files_A = [f for f in os.listdir(dir) if os.path.isfile(os.path.join(mypath, f) and is_image_file(f))]
+ files_B = [f for f in os.listdir(dir_b) if os.path.isfile(os.path.join(mypath, f) and is_image_file(f))]
+ images = sorted(list(set(files_A).intersection(files_B)))
+ # path = os.path.join(root, fname)
+ # images.append(path)
+ images_A = [os.path.join(dir, f) for f in images]
+ images_B = [os.path.join(dir_b, f) for f in images]
+ return images_A, images_B
+
def default_loader(path):
return Image.open(path).convert('RGB')