diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/aligned_data_loader.py | 17 | ||||
| -rw-r--r-- | data/unaligned_data_loader.py | 16 |
2 files changed, 21 insertions, 12 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py index bea3531..a1efde8 100644 --- a/data/aligned_data_loader.py +++ b/data/aligned_data_loader.py @@ -4,19 +4,26 @@ import torchvision.transforms as transforms from data.base_data_loader import BaseDataLoader from data.image_folder import ImageFolder from pdb import set_trace as st +# pip install future --upgrade from builtins import object class PairedData(object): - def __init__(self, data_loader, fineSize): + def __init__(self, data_loader, fineSize, max_dataset_size): self.data_loader = data_loader self.fineSize = fineSize + self.max_dataset_size = max_dataset_size # st() def __iter__(self): self.data_loader_iter = iter(self.data_loader) + self.iter = 0 return self def __next__(self): + self.iter += 1 + if self.iter > self.max_dataset_size: + raise StopIteration + AB, AB_paths = next(self.data_loader_iter) w_total = AB.size(3) w = int(w_total / 2) @@ -24,7 +31,6 @@ class PairedData(object): w_offset = random.randint(0, max(0, w - self.fineSize - 1)) h_offset = random.randint(0, max(0, h - self.fineSize - 1)) - A = AB[:, :, h_offset:h_offset + self.fineSize, w_offset:w_offset + self.fineSize] B = AB[:, :, h_offset:h_offset + self.fineSize, @@ -39,8 +45,7 @@ class AlignedDataLoader(BaseDataLoader): self.fineSize = opt.fineSize transform = transforms.Compose([ # TODO: Scale - #transforms.Scale((opt.loadSize * 2, opt.loadSize)), - #transforms.CenterCrop(opt.fineSize), + transforms.Scale(opt.loadSize), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) @@ -55,7 +60,7 @@ class AlignedDataLoader(BaseDataLoader): num_workers=int(self.opt.nThreads)) self.dataset = dataset - self.paired_data = PairedData(data_loader, opt.fineSize) + self.paired_data = PairedData(data_loader, opt.fineSize, opt.max_dataset_size) def name(self): return 'AlignedDataLoader' @@ -64,4 +69,4 @@ class AlignedDataLoader(BaseDataLoader): return self.paired_data def __len__(self): - return len(self.dataset) + return min(len(self.dataset), self.opt.max_dataset_size) diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py index 4a06510..77f9274 100644 --- a/data/unaligned_data_loader.py +++ b/data/unaligned_data_loader.py @@ -2,21 +2,24 @@ import torch.utils.data import torchvision.transforms as transforms from data.base_data_loader import BaseDataLoader from data.image_folder import ImageFolder +# pip install future --upgrade from builtins import object from pdb import set_trace as st class PairedData(object): - def __init__(self, data_loader_A, data_loader_B): + def __init__(self, data_loader_A, data_loader_B, max_dataset_size): self.data_loader_A = data_loader_A self.data_loader_B = data_loader_B self.stop_A = False self.stop_B = False + self.max_dataset_size = max_dataset_size def __iter__(self): self.stop_A = False self.stop_B = False self.data_loader_A_iter = iter(self.data_loader_A) self.data_loader_B_iter = iter(self.data_loader_B) + self.iter = 0 return self def __next__(self): @@ -29,20 +32,21 @@ class PairedData(object): self.stop_A = True self.data_loader_A_iter = iter(self.data_loader_A) A, A_paths = next(self.data_loader_A_iter) + try: B, B_paths = next(self.data_loader_B_iter) - except StopIteration: if B is None or B_paths is None: self.stop_B = True self.data_loader_B_iter = iter(self.data_loader_B) B, B_paths = next(self.data_loader_B_iter) - if self.stop_A and self.stop_B: + if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size: self.stop_A = False self.stop_B = False raise StopIteration() else: + self.iter += 1 return {'A': A, 'A_paths': A_paths, 'B': B, 'B_paths': B_paths} @@ -51,7 +55,7 @@ class UnalignedDataLoader(BaseDataLoader): BaseDataLoader.initialize(self, opt) transform = transforms.Compose([ transforms.Scale(opt.loadSize), - transforms.CenterCrop(opt.fineSize), + transforms.RandomCrop(opt.fineSize), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) @@ -75,7 +79,7 @@ class UnalignedDataLoader(BaseDataLoader): num_workers=int(self.opt.nThreads)) self.dataset_A = dataset_A self.dataset_B = dataset_B - self.paired_data = PairedData(data_loader_A, data_loader_B) + self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.max_dataset_size) def name(self): return 'UnalignedDataLoader' @@ -84,4 +88,4 @@ class UnalignedDataLoader(BaseDataLoader): return self.paired_data def __len__(self): - return max(len(self.dataset_A), len(self.dataset_B)) + return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.max_dataset_size) |
