-rw-r--r--  README.md                            2
-rw-r--r--  data/__init__.py                    54
-rw-r--r--  data/custom_dataset_data_loader.py  47
-rw-r--r--  data/data_loader.py                  6
-rw-r--r--  make_dataset_aligned.py             63
-rw-r--r--  models/__init__.py                  20
-rw-r--r--  models/models.py                    20
-rw-r--r--  test.py                             52
-rw-r--r--  train.py                             4

9 files changed, 104 insertions, 164 deletions
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ # CycleGAN and pix2pix in PyTorch
 
-This is our ongoing PyTorch implementation for both unpaired and paired image-to-image translation.
+This is our PyTorch implementation for both unpaired and paired image-to-image translation.
 
 The code was written by [Jun-Yan Zhu](https://github.com/junyanz) and [Taesung Park](https://github.com/taesung89).
diff --git a/data/__init__.py b/data/__init__.py
index e69de29..ef581e7 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -0,0 +1,54 @@
+import torch.utils.data
+from data.base_data_loader import BaseDataLoader
+
+
+def CreateDataLoader(opt):
+    data_loader = CustomDatasetDataLoader()
+    print(data_loader.name())
+    data_loader.initialize(opt)
+    return data_loader
+
+
+def CreateDataset(opt):
+    dataset = None
+    if opt.dataset_mode == 'aligned':
+        from data.aligned_dataset import AlignedDataset
+        dataset = AlignedDataset()
+    elif opt.dataset_mode == 'unaligned':
+        from data.unaligned_dataset import UnalignedDataset
+        dataset = UnalignedDataset()
+    elif opt.dataset_mode == 'single':
+        from data.single_dataset import SingleDataset
+        dataset = SingleDataset()
+    else:
+        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
+
+    print("dataset [%s] was created" % (dataset.name()))
+    dataset.initialize(opt)
+    return dataset
+
+
+class CustomDatasetDataLoader(BaseDataLoader):
+    def name(self):
+        return 'CustomDatasetDataLoader'
+
+    def initialize(self, opt):
+        BaseDataLoader.initialize(self, opt)
+        self.dataset = CreateDataset(opt)
+        self.dataloader = torch.utils.data.DataLoader(
+            self.dataset,
+            batch_size=opt.batchSize,
+            shuffle=not opt.serial_batches,
+            num_workers=int(opt.nThreads))
+
+    def load_data(self):
+        return self
+
+    def __len__(self):
+        return min(len(self.dataset), self.opt.max_dataset_size)
+
+    def __iter__(self):
+        for i, data in enumerate(self.dataloader):
+            if i >= self.opt.max_dataset_size:
+                break
+            yield data
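Usage sketch (not part of the commit): with the loader factory folded into the package, callers import it straight from `data`. A minimal sketch of the call pattern, mirroring how train.py and test.py consume it after this change; `TrainOptions().parse()` is assumed to behave like `TestOptions().parse()` in test.py, i.e. it returns an `opt` carrying the fields the factory reads (`dataset_mode`, `batchSize`, `serial_batches`, `nThreads`, `max_dataset_size`).

```python
# Sketch only: runs inside this repository with a configured dataset; names come from the diff above.
from options.train_options import TrainOptions  # repo option parser (unchanged by this commit)
from data import CreateDataLoader               # new package-level import path

if __name__ == '__main__':
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)     # picks Aligned/Unaligned/SingleDataset from opt.dataset_mode
    dataset = data_loader.load_data()       # iterable of batches, capped at opt.max_dataset_size
    print('#training images = %d' % len(dataset))
    for i, data in enumerate(dataset):      # each `data` is one batch from the chosen dataset
        break                               # one batch is enough for the sketch
```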
diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py
deleted file mode 100644
index 787946f..0000000
--- a/data/custom_dataset_data_loader.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import torch.utils.data
-from data.base_data_loader import BaseDataLoader
-
-
-def CreateDataset(opt):
-    dataset = None
-    if opt.dataset_mode == 'aligned':
-        from data.aligned_dataset import AlignedDataset
-        dataset = AlignedDataset()
-    elif opt.dataset_mode == 'unaligned':
-        from data.unaligned_dataset import UnalignedDataset
-        dataset = UnalignedDataset()
-    elif opt.dataset_mode == 'single':
-        from data.single_dataset import SingleDataset
-        dataset = SingleDataset()
-    else:
-        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
-
-    print("dataset [%s] was created" % (dataset.name()))
-    dataset.initialize(opt)
-    return dataset
-
-
-class CustomDatasetDataLoader(BaseDataLoader):
-    def name(self):
-        return 'CustomDatasetDataLoader'
-
-    def initialize(self, opt):
-        BaseDataLoader.initialize(self, opt)
-        self.dataset = CreateDataset(opt)
-        self.dataloader = torch.utils.data.DataLoader(
-            self.dataset,
-            batch_size=opt.batchSize,
-            shuffle=not opt.serial_batches,
-            num_workers=int(opt.nThreads))
-
-    def load_data(self):
-        return self
-
-    def __len__(self):
-        return min(len(self.dataset), self.opt.max_dataset_size)
-
-    def __iter__(self):
-        for i, data in enumerate(self.dataloader):
-            if i >= self.opt.max_dataset_size:
-                break
-            yield data
diff --git a/data/data_loader.py b/data/data_loader.py
deleted file mode 100644
index 22b6a8f..0000000
--- a/data/data_loader.py
+++ /dev/null
@@ -1,6 +0,0 @@
-def CreateDataLoader(opt):
-    from data.custom_dataset_data_loader import CustomDatasetDataLoader
-    data_loader = CustomDatasetDataLoader()
-    print(data_loader.name())
-    data_loader.initialize(opt)
-    return data_loader
diff --git a/make_dataset_aligned.py b/make_dataset_aligned.py
deleted file mode 100644
index 739c767..0000000
--- a/make_dataset_aligned.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import os
-
-from PIL import Image
-
-
-def get_file_paths(folder):
-    image_file_paths = []
-    for root, dirs, filenames in os.walk(folder):
-        filenames = sorted(filenames)
-        for filename in filenames:
-            input_path = os.path.abspath(root)
-            file_path = os.path.join(input_path, filename)
-            if filename.endswith('.png') or filename.endswith('.jpg'):
-                image_file_paths.append(file_path)
-
-        break  # prevent descending into subfolders
-    return image_file_paths
-
-
-def align_images(a_file_paths, b_file_paths, target_path):
-    if not os.path.exists(target_path):
-        os.makedirs(target_path)
-
-    for i in range(len(a_file_paths)):
-        img_a = Image.open(a_file_paths[i])
-        img_b = Image.open(b_file_paths[i])
-        assert(img_a.size == img_b.size)
-
-        aligned_image = Image.new("RGB", (img_a.size[0] * 2, img_a.size[1]))
-        aligned_image.paste(img_a, (0, 0))
-        aligned_image.paste(img_b, (img_a.size[0], 0))
-        aligned_image.save(os.path.join(target_path, '{:04d}.jpg'.format(i)))
-
-
-if __name__ == '__main__':
-    import argparse
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--dataset-path',
-        dest='dataset_path',
-        help='Which folder to process (it should have subfolders testA, testB, trainA and trainB'
-    )
-    args = parser.parse_args()
-
-    dataset_folder = args.dataset_path
-    print(dataset_folder)
-
-    test_a_path = os.path.join(dataset_folder, 'testA')
-    test_b_path = os.path.join(dataset_folder, 'testB')
-    test_a_file_paths = get_file_paths(test_a_path)
-    test_b_file_paths = get_file_paths(test_b_path)
-    assert(len(test_a_file_paths) == len(test_b_file_paths))
-    test_path = os.path.join(dataset_folder, 'test')
-
-    train_a_path = os.path.join(dataset_folder, 'trainA')
-    train_b_path = os.path.join(dataset_folder, 'trainB')
-    train_a_file_paths = get_file_paths(train_a_path)
-    train_b_file_paths = get_file_paths(train_b_path)
-    assert(len(train_a_file_paths) == len(train_b_file_paths))
-    train_path = os.path.join(dataset_folder, 'train')
-
-    align_images(test_a_file_paths, test_b_file_paths, test_path)
-    align_images(train_a_file_paths, train_b_file_paths, train_path)
diff --git a/models/__init__.py b/models/__init__.py
index e69de29..681c6de 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -0,0 +1,20 @@
+def create_model(opt):
+    model = None
+    print(opt.model)
+    if opt.model == 'cycle_gan':
+        assert(opt.dataset_mode == 'unaligned')
+        from .cycle_gan_model import CycleGANModel
+        model = CycleGANModel()
+    elif opt.model == 'pix2pix':
+        assert(opt.dataset_mode == 'aligned')
+        from .pix2pix_model import Pix2PixModel
+        model = Pix2PixModel()
+    elif opt.model == 'test':
+        assert(opt.dataset_mode == 'single')
+        from .test_model import TestModel
+        model = TestModel()
+    else:
+        raise NotImplementedError('model [%s] not implemented.' % opt.model)
+    model.initialize(opt)
+    print("model [%s] was created" % (model.name()))
+    return model
diff --git a/models/models.py b/models/models.py
deleted file mode 100644
index 39cc020..0000000
--- a/models/models.py
+++ /dev/null
@@ -1,20 +0,0 @@
-def create_model(opt):
-    model = None
-    print(opt.model)
-    if opt.model == 'cycle_gan':
-        assert(opt.dataset_mode == 'unaligned')
-        from .cycle_gan_model import CycleGANModel
-        model = CycleGANModel()
-    elif opt.model == 'pix2pix':
-        assert(opt.dataset_mode == 'aligned')
-        from .pix2pix_model import Pix2PixModel
-        model = Pix2PixModel()
-    elif opt.model == 'test':
-        assert(opt.dataset_mode == 'single')
-        from .test_model import TestModel
-        model = TestModel()
-    else:
-        raise ValueError("Model [%s] not recognized." % opt.model)
-    model.initialize(opt)
-    print("model [%s] was created" % (model.name()))
-    return model
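Usage sketch (not part of the commit): the model factory is consumed the same way, now imported from the `models` package. A minimal test-time sketch mirroring the call sequence in test.py below; the only constraint is that `opt.model` pairs with a matching `opt.dataset_mode` (e.g. 'pix2pix' with 'aligned'), which the factory asserts.

```python
# Sketch only: mirrors test.py; all names come from this repository's diff.
from options.test_options import TestOptions
from data import CreateDataLoader
from models import create_model          # new package-level import path

if __name__ == '__main__':
    opt = TestOptions().parse()              # opt.model must match opt.dataset_mode
    dataset = CreateDataLoader(opt).load_data()
    model = create_model(opt)                # e.g. model='pix2pix' asserts dataset_mode='aligned'

    for i, data in enumerate(dataset):
        model.set_input(data)                # feed one batch
        model.test()                         # run inference on the batch
        visuals = model.get_current_visuals()  # images test.py hands to the visualizer
        break                                # one batch is enough for the sketch
```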
diff --git a/test.py b/test.py
--- a/test.py
+++ b/test.py
@@ -1,32 +1,34 @@
 import os
 from options.test_options import TestOptions
-from data.data_loader import CreateDataLoader
-from models.models import create_model
+from data import CreateDataLoader
+from models import create_model
 from util.visualizer import Visualizer
 from util import html
 
-opt = TestOptions().parse()
-opt.nThreads = 1   # test code only supports nThreads = 1
-opt.batchSize = 1  # test code only supports batchSize = 1
-opt.serial_batches = True  # no shuffle
-opt.no_flip = True  # no flip
-data_loader = CreateDataLoader(opt)
-dataset = data_loader.load_data()
-model = create_model(opt)
-visualizer = Visualizer(opt)
-# create website
-web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
-webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
-# test
-for i, data in enumerate(dataset):
-    if i >= opt.how_many:
-        break
-    model.set_input(data)
-    model.test()
-    visuals = model.get_current_visuals()
-    img_path = model.get_image_paths()
-    print('%04d: process image... %s' % (i, img_path))
-    visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio)
+if __name__ == '__main__':
+    opt = TestOptions().parse()
+    opt.nThreads = 1   # test code only supports nThreads = 1
+    opt.batchSize = 1  # test code only supports batchSize = 1
+    opt.serial_batches = True  # no shuffle
+    opt.no_flip = True  # no flip
 
-webpage.save()
+    data_loader = CreateDataLoader(opt)
+    dataset = data_loader.load_data()
+    model = create_model(opt)
+    visualizer = Visualizer(opt)
+    # create website
+    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
+    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
+    # test
+    for i, data in enumerate(dataset):
+        if i >= opt.how_many:
+            break
+        model.set_input(data)
+        model.test()
+        visuals = model.get_current_visuals()
+        img_path = model.get_image_paths()
+        print('%04d: process image... %s' % (i, img_path))
+        visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio)
+
+    webpage.save()
diff --git a/train.py b/train.py
--- a/train.py
+++ b/train.py
@@ -1,7 +1,7 @@
 import time
 from options.train_options import TrainOptions
-from data.data_loader import CreateDataLoader
-from models.models import create_model
+from data import CreateDataLoader
+from models import create_model
 from util.visualizer import Visualizer
 
 if __name__ == '__main__':
