diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-06-12 23:52:56 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-06-12 23:52:56 -0700 |
| commit | e6858e35f0a08c6139c133122d222d0d85e8005d (patch) | |
| tree | 2647ff13a164c684113eab455123394a49a65dad /data/aligned_dataset.py | |
| parent | 3b72a659c38141e502b74bee65ca08d51dc3eabf (diff) | |
update dataset mode
Diffstat (limited to 'data/aligned_dataset.py')
| -rw-r--r-- | data/aligned_dataset.py | 56 |
1 file changed, 56 insertions, 0 deletions
import os.path
import random

import torch
import torchvision.transforms as transforms
from PIL import Image

from data.base_dataset import BaseDataset
from data.image_folder import make_dataset


class AlignedDataset(BaseDataset):
    """Dataset of paired (aligned) images stored side by side in one file.

    Each image on disk is expected to hold the domain-A image on the left
    half and the corresponding domain-B image on the right half (the pix2pix
    "AB" layout).  Pairs are resized, randomly cropped with a shared offset
    (so A and B stay pixel-aligned), normalized to [-1, 1], and optionally
    horizontally flipped together.
    """

    def initialize(self, opt):
        """Set up paths and the tensor transform from the options object.

        Expects ``opt`` to provide: dataroot, phase, resize_or_crop,
        loadSize, fineSize, no_flip.
        """
        self.opt = opt
        self.root = opt.dataroot
        # Images for the current phase (train/test/val) live in <dataroot>/<phase>.
        self.dir_AB = os.path.join(opt.dataroot, opt.phase)

        # Sort for a deterministic index -> file mapping across runs.
        self.AB_paths = sorted(make_dataset(self.dir_AB))

        # This dataset only implements the resize-then-random-crop pipeline.
        assert(opt.resize_or_crop == 'resize_and_crop')

        # ToTensor yields [0, 1]; Normalize maps each channel to [-1, 1],
        # the input range conventionally expected by the GAN generator.
        transform_list = [transforms.ToTensor(),
                          transforms.Normalize((0.5, 0.5, 0.5),
                                               (0.5, 0.5, 0.5))]

        self.transform = transforms.Compose(transform_list)

    def __getitem__(self, index):
        """Return one aligned pair as ``{'A', 'B', 'A_paths', 'B_paths'}``.

        A and B are CHW float tensors of spatial size fineSize x fineSize
        in [-1, 1]; both path entries point at the same combined AB file.
        """
        AB_path = self.AB_paths[index]
        AB = Image.open(AB_path).convert('RGB')
        # Width is 2 * loadSize because A and B sit side by side in one image.
        AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC)
        AB = self.transform(AB)

        w_total = AB.size(2)
        w = int(w_total / 2)  # width of each half (A and B)
        h = AB.size(1)
        # Shared random crop offset keeps the A/B pair aligned.
        # Upper bound is w - fineSize (inclusive, since randint includes both
        # endpoints); the previous "- 1" was an off-by-one that excluded the
        # right-most / bottom-most valid crop position.
        w_offset = random.randint(0, max(0, w - self.opt.fineSize))
        h_offset = random.randint(0, max(0, h - self.opt.fineSize))

        A = AB[:, h_offset:h_offset + self.opt.fineSize,
               w_offset:w_offset + self.opt.fineSize]
        B = AB[:, h_offset:h_offset + self.opt.fineSize,
               w + w_offset:w + w_offset + self.opt.fineSize]

        # Horizontal-flip augmentation, applied identically to both halves
        # so the pair stays aligned.
        if (not self.opt.no_flip) and random.random() < 0.5:
            idx = torch.arange(A.size(2) - 1, -1, -1, dtype=torch.long)
            A = A.index_select(2, idx)
            B = B.index_select(2, idx)

        return {'A': A, 'B': B,
                'A_paths': AB_path, 'B_paths': AB_path}

    def __len__(self):
        """Number of A/B image pairs found on disk."""
        return len(self.AB_paths)

    def name(self):
        """Identifier used by the dataset factory and logging."""
        return 'AlignedDataset'
