summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md2
-rw-r--r--data/__init__.py54
-rw-r--r--data/custom_dataset_data_loader.py47
-rw-r--r--data/data_loader.py6
-rw-r--r--make_dataset_aligned.py63
-rw-r--r--models/__init__.py20
-rw-r--r--models/models.py20
-rw-r--r--test.py52
-rw-r--r--train.py4
9 files changed, 104 insertions, 164 deletions
diff --git a/README.md b/README.md
index 2c93bfa..eb454e7 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
# CycleGAN and pix2pix in PyTorch
-This is our ongoing PyTorch implementation for both unpaired and paired image-to-image translation.
+This is our PyTorch implementation for both unpaired and paired image-to-image translation.
The code was written by [Jun-Yan Zhu](https://github.com/junyanz) and [Taesung Park](https://github.com/taesung89).
diff --git a/data/__init__.py b/data/__init__.py
index e69de29..ef581e7 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -0,0 +1,54 @@
+import torch.utils.data
+from data.base_data_loader import BaseDataLoader
+
+
+def CreateDataLoader(opt):
+ data_loader = CustomDatasetDataLoader()
+ print(data_loader.name())
+ data_loader.initialize(opt)
+ return data_loader
+
+
+def CreateDataset(opt):
+ dataset = None
+ if opt.dataset_mode == 'aligned':
+ from data.aligned_dataset import AlignedDataset
+ dataset = AlignedDataset()
+ elif opt.dataset_mode == 'unaligned':
+ from data.unaligned_dataset import UnalignedDataset
+ dataset = UnalignedDataset()
+ elif opt.dataset_mode == 'single':
+ from data.single_dataset import SingleDataset
+ dataset = SingleDataset()
+ else:
+ raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
+
+ print("dataset [%s] was created" % (dataset.name()))
+ dataset.initialize(opt)
+ return dataset
+
+
+class CustomDatasetDataLoader(BaseDataLoader):
+ def name(self):
+ return 'CustomDatasetDataLoader'
+
+ def initialize(self, opt):
+ BaseDataLoader.initialize(self, opt)
+ self.dataset = CreateDataset(opt)
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=opt.batchSize,
+ shuffle=not opt.serial_batches,
+ num_workers=int(opt.nThreads))
+
+ def load_data(self):
+ return self
+
+ def __len__(self):
+ return min(len(self.dataset), self.opt.max_dataset_size)
+
+ def __iter__(self):
+ for i, data in enumerate(self.dataloader):
+ if i >= self.opt.max_dataset_size:
+ break
+ yield data
diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py
deleted file mode 100644
index 787946f..0000000
--- a/data/custom_dataset_data_loader.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import torch.utils.data
-from data.base_data_loader import BaseDataLoader
-
-
-def CreateDataset(opt):
- dataset = None
- if opt.dataset_mode == 'aligned':
- from data.aligned_dataset import AlignedDataset
- dataset = AlignedDataset()
- elif opt.dataset_mode == 'unaligned':
- from data.unaligned_dataset import UnalignedDataset
- dataset = UnalignedDataset()
- elif opt.dataset_mode == 'single':
- from data.single_dataset import SingleDataset
- dataset = SingleDataset()
- else:
- raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
-
- print("dataset [%s] was created" % (dataset.name()))
- dataset.initialize(opt)
- return dataset
-
-
-class CustomDatasetDataLoader(BaseDataLoader):
- def name(self):
- return 'CustomDatasetDataLoader'
-
- def initialize(self, opt):
- BaseDataLoader.initialize(self, opt)
- self.dataset = CreateDataset(opt)
- self.dataloader = torch.utils.data.DataLoader(
- self.dataset,
- batch_size=opt.batchSize,
- shuffle=not opt.serial_batches,
- num_workers=int(opt.nThreads))
-
- def load_data(self):
- return self
-
- def __len__(self):
- return min(len(self.dataset), self.opt.max_dataset_size)
-
- def __iter__(self):
- for i, data in enumerate(self.dataloader):
- if i >= self.opt.max_dataset_size:
- break
- yield data
diff --git a/data/data_loader.py b/data/data_loader.py
deleted file mode 100644
index 22b6a8f..0000000
--- a/data/data_loader.py
+++ /dev/null
@@ -1,6 +0,0 @@
-def CreateDataLoader(opt):
- from data.custom_dataset_data_loader import CustomDatasetDataLoader
- data_loader = CustomDatasetDataLoader()
- print(data_loader.name())
- data_loader.initialize(opt)
- return data_loader
diff --git a/make_dataset_aligned.py b/make_dataset_aligned.py
deleted file mode 100644
index 739c767..0000000
--- a/make_dataset_aligned.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import os
-
-from PIL import Image
-
-
-def get_file_paths(folder):
- image_file_paths = []
- for root, dirs, filenames in os.walk(folder):
- filenames = sorted(filenames)
- for filename in filenames:
- input_path = os.path.abspath(root)
- file_path = os.path.join(input_path, filename)
- if filename.endswith('.png') or filename.endswith('.jpg'):
- image_file_paths.append(file_path)
-
- break # prevent descending into subfolders
- return image_file_paths
-
-
-def align_images(a_file_paths, b_file_paths, target_path):
- if not os.path.exists(target_path):
- os.makedirs(target_path)
-
- for i in range(len(a_file_paths)):
- img_a = Image.open(a_file_paths[i])
- img_b = Image.open(b_file_paths[i])
- assert(img_a.size == img_b.size)
-
- aligned_image = Image.new("RGB", (img_a.size[0] * 2, img_a.size[1]))
- aligned_image.paste(img_a, (0, 0))
- aligned_image.paste(img_b, (img_a.size[0], 0))
- aligned_image.save(os.path.join(target_path, '{:04d}.jpg'.format(i)))
-
-
-if __name__ == '__main__':
- import argparse
- parser = argparse.ArgumentParser()
- parser.add_argument(
- '--dataset-path',
- dest='dataset_path',
- help='Which folder to process (it should have subfolders testA, testB, trainA and trainB'
- )
- args = parser.parse_args()
-
- dataset_folder = args.dataset_path
- print(dataset_folder)
-
- test_a_path = os.path.join(dataset_folder, 'testA')
- test_b_path = os.path.join(dataset_folder, 'testB')
- test_a_file_paths = get_file_paths(test_a_path)
- test_b_file_paths = get_file_paths(test_b_path)
- assert(len(test_a_file_paths) == len(test_b_file_paths))
- test_path = os.path.join(dataset_folder, 'test')
-
- train_a_path = os.path.join(dataset_folder, 'trainA')
- train_b_path = os.path.join(dataset_folder, 'trainB')
- train_a_file_paths = get_file_paths(train_a_path)
- train_b_file_paths = get_file_paths(train_b_path)
- assert(len(train_a_file_paths) == len(train_b_file_paths))
- train_path = os.path.join(dataset_folder, 'train')
-
- align_images(test_a_file_paths, test_b_file_paths, test_path)
- align_images(train_a_file_paths, train_b_file_paths, train_path)
diff --git a/models/__init__.py b/models/__init__.py
index e69de29..681c6de 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -0,0 +1,20 @@
+def create_model(opt):
+ model = None
+ print(opt.model)
+ if opt.model == 'cycle_gan':
+ assert(opt.dataset_mode == 'unaligned')
+ from .cycle_gan_model import CycleGANModel
+ model = CycleGANModel()
+ elif opt.model == 'pix2pix':
+ assert(opt.dataset_mode == 'aligned')
+ from .pix2pix_model import Pix2PixModel
+ model = Pix2PixModel()
+ elif opt.model == 'test':
+ assert(opt.dataset_mode == 'single')
+ from .test_model import TestModel
+ model = TestModel()
+ else:
+ raise NotImplementedError('model [%s] not implemented.' % opt.model)
+ model.initialize(opt)
+ print("model [%s] was created" % (model.name()))
+ return model
diff --git a/models/models.py b/models/models.py
deleted file mode 100644
index 39cc020..0000000
--- a/models/models.py
+++ /dev/null
@@ -1,20 +0,0 @@
-def create_model(opt):
- model = None
- print(opt.model)
- if opt.model == 'cycle_gan':
- assert(opt.dataset_mode == 'unaligned')
- from .cycle_gan_model import CycleGANModel
- model = CycleGANModel()
- elif opt.model == 'pix2pix':
- assert(opt.dataset_mode == 'aligned')
- from .pix2pix_model import Pix2PixModel
- model = Pix2PixModel()
- elif opt.model == 'test':
- assert(opt.dataset_mode == 'single')
- from .test_model import TestModel
- model = TestModel()
- else:
- raise ValueError("Model [%s] not recognized." % opt.model)
- model.initialize(opt)
- print("model [%s] was created" % (model.name()))
- return model
diff --git a/test.py b/test.py
index 863e550..8444bd9 100644
--- a/test.py
+++ b/test.py
@@ -1,32 +1,34 @@
import os
from options.test_options import TestOptions
-from data.data_loader import CreateDataLoader
-from models.models import create_model
+from data import CreateDataLoader
+from models import create_model
from util.visualizer import Visualizer
from util import html
-opt = TestOptions().parse()
-opt.nThreads = 1 # test code only supports nThreads = 1
-opt.batchSize = 1 # test code only supports batchSize = 1
-opt.serial_batches = True # no shuffle
-opt.no_flip = True # no flip
-data_loader = CreateDataLoader(opt)
-dataset = data_loader.load_data()
-model = create_model(opt)
-visualizer = Visualizer(opt)
-# create website
-web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
-webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
-# test
-for i, data in enumerate(dataset):
- if i >= opt.how_many:
- break
- model.set_input(data)
- model.test()
- visuals = model.get_current_visuals()
- img_path = model.get_image_paths()
- print('%04d: process image... %s' % (i, img_path))
- visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio)
+if __name__ == '__main__':
+ opt = TestOptions().parse()
+ opt.nThreads = 1 # test code only supports nThreads = 1
+ opt.batchSize = 1 # test code only supports batchSize = 1
+ opt.serial_batches = True # no shuffle
+ opt.no_flip = True # no flip
-webpage.save()
+ data_loader = CreateDataLoader(opt)
+ dataset = data_loader.load_data()
+ model = create_model(opt)
+ visualizer = Visualizer(opt)
+ # create website
+ web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
+ webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
+ # test
+ for i, data in enumerate(dataset):
+ if i >= opt.how_many:
+ break
+ model.set_input(data)
+ model.test()
+ visuals = model.get_current_visuals()
+ img_path = model.get_image_paths()
+ print('%04d: process image... %s' % (i, img_path))
+ visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio)
+
+ webpage.save()
diff --git a/train.py b/train.py
index ee8cff1..dd79bab 100644
--- a/train.py
+++ b/train.py
@@ -1,7 +1,7 @@
import time
from options.train_options import TrainOptions
-from data.data_loader import CreateDataLoader
-from models.models import create_model
+from data import CreateDataLoader
+from models import create_model
from util.visualizer import Visualizer
if __name__ == '__main__':