Diffstat (limited to 'data')
-rw-r--r--  data/base_data_loader.py   6
-rw-r--r--  data/base_dataset.py       3
-rw-r--r--  data/data_loader.py        1
-rw-r--r--  data/single_dataset.py     1
-rw-r--r--  data/unaligned_dataset.py  4
5 files changed, 5 insertions, 10 deletions
diff --git a/data/base_data_loader.py b/data/base_data_loader.py
index 0e1deb5..ae5a168 100644
--- a/data/base_data_loader.py
+++ b/data/base_data_loader.py
@@ -1,14 +1,10 @@
-
 class BaseDataLoader():
     def __init__(self):
         pass
-
+
     def initialize(self, opt):
         self.opt = opt
         pass
     def load_data():
         return None
-
-
-
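BaseDataLoader is only the interface here; concrete loaders override initialize() and load_data(). Below is a minimal sketch of such a subclass that wraps a torch DataLoader; the class name, the extra dataset argument, and the batchSize option name are illustrative assumptions, not taken from this diff (the repository's real loader lives in data/custom_dataset_data_loader.py).

import torch.utils.data

from data.base_data_loader import BaseDataLoader


class ExampleDataLoader(BaseDataLoader):
    # Illustrative subclass: the dataset argument and the option names used
    # here are assumptions made for this sketch.
    def initialize(self, opt, dataset):
        BaseDataLoader.initialize(self, opt)
        self.dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=getattr(opt, 'batchSize', 1),
            shuffle=not getattr(opt, 'serial_batches', False))

    def load_data(self):
        # Callers iterate over whatever this returns (see data_loader.py).
        return self.dataloader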
diff --git a/data/base_dataset.py b/data/base_dataset.py
index a061a05..7cfac54 100644
--- a/data/base_dataset.py
+++ b/data/base_dataset.py
@@ -2,6 +2,7 @@ import torch.utils.data as data
 from PIL import Image
 import torchvision.transforms as transforms
+
 class BaseDataset(data.Dataset):
     def __init__(self):
         super(BaseDataset, self).__init__()
@@ -12,6 +13,7 @@ class BaseDataset(data.Dataset):
     def initialize(self, opt):
         pass
+
 def get_transform(opt):
     transform_list = []
     if opt.resize_or_crop == 'resize_and_crop':
@@ -36,6 +38,7 @@ def get_transform(opt):
                                             (0.5, 0.5, 0.5))]
     return transforms.Compose(transform_list)
+
 def __scale_width(img, target_width):
     ow, oh = img.size
     if (ow == target_width):
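get_transform() assembles a torchvision Compose driven by opt.resize_or_crop and ends with the ToTensor/Normalize pair visible in the context above. A usage sketch follows; every option field other than resize_or_crop (loadSize, fineSize, no_flip, isTrain) is an assumption about the option object, not something this diff shows.

from PIL import Image

from data.base_dataset import get_transform


class FakeOpt:
    # Stand-in option object for the sketch; only resize_or_crop is
    # confirmed by the diff, the remaining fields are guesses.
    resize_or_crop = 'resize_and_crop'
    loadSize = 286
    fineSize = 256
    no_flip = False
    isTrain = True


transform = get_transform(FakeOpt())            # torchvision Compose
img = Image.open('example.jpg').convert('RGB')  # any RGB image
tensor = transform(img)                         # float tensor, roughly in [-1, 1]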
diff --git a/data/data_loader.py b/data/data_loader.py
index 2a4433a..22b6a8f 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -1,4 +1,3 @@
-
 def CreateDataLoader(opt):
     from data.custom_dataset_data_loader import CustomDatasetDataLoader
     data_loader = CustomDatasetDataLoader()
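CreateDataLoader() is the factory that callers go through. A usage sketch, assuming the returned object follows the BaseDataLoader interface shown earlier; the option fields below are placeholders for the sketch, not values read from this diff.

from argparse import Namespace

from data.data_loader import CreateDataLoader

# The real `opt` comes from the project's option parser; these fields are
# assumptions for illustration only.
opt = Namespace(dataroot='./datasets/example',
                resize_or_crop='resize_and_crop',
                serial_batches=False)

data_loader = CreateDataLoader(opt)   # builds and initializes the loader
dataset = data_loader.load_data()     # iterable over batches
for i, batch in enumerate(dataset):
    pass  # batch layout depends on the configured dataset mode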
diff --git a/data/single_dataset.py b/data/single_dataset.py
index f8b4f1d..12083b1 100644
--- a/data/single_dataset.py
+++ b/data/single_dataset.py
@@ -1,5 +1,4 @@
 import os.path
-import torchvision.transforms as transforms
 from data.base_dataset import BaseDataset, get_transform
 from data.image_folder import make_dataset
 from PIL import Image
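The removed torchvision import was unused because the dataset builds its transform through get_transform() from base_dataset. A minimal sketch of that pattern, assuming the usual initialize() shape; the attribute names (A_paths, dataroot) are illustrative.

from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset


class SingleDatasetSketch(BaseDataset):
    # Sketch only: shows why torchvision no longer needs to be imported
    # directly; the transform comes from get_transform().
    def initialize(self, opt):
        self.opt = opt
        self.A_paths = sorted(make_dataset(opt.dataroot))
        self.transform = get_transform(opt)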
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
index ad0c11b..2f59b2a 100644
--- a/data/unaligned_dataset.py
+++ b/data/unaligned_dataset.py
@@ -1,11 +1,10 @@
 import os.path
-import torchvision.transforms as transforms
 from data.base_dataset import BaseDataset, get_transform
 from data.image_folder import make_dataset
 from PIL import Image
-import PIL
 import random
+
 class UnalignedDataset(BaseDataset):
     def initialize(self, opt):
         self.opt = opt
@@ -24,7 +23,6 @@ class UnalignedDataset(BaseDataset):
     def __getitem__(self, index):
         A_path = self.A_paths[index % self.A_size]
-        index_A = index % self.A_size
         if self.opt.serial_batches:
             index_B = index % self.B_size
         else:
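The deleted index_A line was dead code: A_path already applies index % self.A_size, and only index_B needs the serial-versus-random branch. A small sketch of that sampling rule, using the field names from the context lines; the helper function itself is illustrative.

import random


def pick_indices(index, A_size, B_size, serial_batches):
    # Mirrors the branch above: A is indexed directly, B is either paired
    # serially with A or drawn at random (unaligned pairing).
    index_A = index % A_size
    if serial_batches:
        index_B = index % B_size
    else:
        index_B = random.randint(0, B_size - 1)
    return index_A, index_B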