summaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/aligned_dataset.py22
-rw-r--r--data/base_data_loader.py6
-rw-r--r--data/base_dataset.py3
-rw-r--r--data/data_loader.py1
-rw-r--r--data/single_dataset.py1
-rw-r--r--data/unaligned_dataset.py4
6 files changed, 15 insertions, 22 deletions
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index 8899cb2..f153f26 100644
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -18,19 +18,17 @@ class AlignedDataset(BaseDataset):
def __getitem__(self, index):
AB_path = self.AB_paths[index]
AB = Image.open(AB_path).convert('RGB')
- AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC)
- AB = transforms.ToTensor()(AB)
+ w, h = AB.size
+ w2 = int(w / 2)
+ A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
+ B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
+ A = transforms.ToTensor()(A)
+ B = transforms.ToTensor()(B)
+ w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
+ h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
- w_total = AB.size(2)
- w = int(w_total / 2)
- h = AB.size(1)
- w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
- h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
-
- A = AB[:, h_offset:h_offset + self.opt.fineSize,
- w_offset:w_offset + self.opt.fineSize]
- B = AB[:, h_offset:h_offset + self.opt.fineSize,
- w + w_offset:w + w_offset + self.opt.fineSize]
+ A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
+ B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B)
diff --git a/data/base_data_loader.py b/data/base_data_loader.py
index 0e1deb5..ae5a168 100644
--- a/data/base_data_loader.py
+++ b/data/base_data_loader.py
@@ -1,14 +1,10 @@
-
class BaseDataLoader():
def __init__(self):
pass
-
+
def initialize(self, opt):
self.opt = opt
pass
def load_data():
return None
-
-
-
diff --git a/data/base_dataset.py b/data/base_dataset.py
index a061a05..7cfac54 100644
--- a/data/base_dataset.py
+++ b/data/base_dataset.py
@@ -2,6 +2,7 @@ import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
+
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
@@ -12,6 +13,7 @@ class BaseDataset(data.Dataset):
def initialize(self, opt):
pass
+
def get_transform(opt):
transform_list = []
if opt.resize_or_crop == 'resize_and_crop':
@@ -36,6 +38,7 @@ def get_transform(opt):
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
+
def __scale_width(img, target_width):
ow, oh = img.size
if (ow == target_width):
diff --git a/data/data_loader.py b/data/data_loader.py
index 2a4433a..22b6a8f 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -1,4 +1,3 @@
-
def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
diff --git a/data/single_dataset.py b/data/single_dataset.py
index f8b4f1d..12083b1 100644
--- a/data/single_dataset.py
+++ b/data/single_dataset.py
@@ -1,5 +1,4 @@
import os.path
-import torchvision.transforms as transforms
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
index ad0c11b..2f59b2a 100644
--- a/data/unaligned_dataset.py
+++ b/data/unaligned_dataset.py
@@ -1,11 +1,10 @@
import os.path
-import torchvision.transforms as transforms
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
-import PIL
import random
+
class UnalignedDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
@@ -24,7 +23,6 @@ class UnalignedDataset(BaseDataset):
def __getitem__(self, index):
A_path = self.A_paths[index % self.A_size]
- index_A = index % self.A_size
if self.opt.serial_batches:
index_B = index % self.B_size
else: