-rw-r--r--  .gitignore               1
-rw-r--r--  environment.yml         14
-rw-r--r--  make_dataset_aligned.py 63
-rw-r--r--  train.py                85
4 files changed, 121 insertions, 42 deletions
diff --git a/.gitignore b/.gitignore
index faba6c9..4fdef3e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -40,3 +40,4 @@ test/.coverage
 */**/*.dylib*
 test/data/legacy_serialized.pt
 *~
+.idea
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..116d052
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,14 @@
+name: pytorch-CycleGAN-and-pix2pix
+channels:
+- peterjc123
+- defaults
+dependencies:
+- python=3.5.5
+- pytorch=0.3.1
+- scipy
+- pip:
+  - dominate==2.3.1
+  - git+https://github.com/pytorch/vision.git
+  - Pillow==5.0.0
+  - numpy==1.14.1
+  - visdom==0.1.7
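
The environment above pins the early peterjc123 Windows builds (Python 3.5.5, PyTorch 0.3.1). As an optional, illustrative sanity check once the environment has been created and activated (nothing below is part of the commit), the pinned packages can be imported and their versions compared against the file:

    # Illustrative check of the pinned environment (not part of the commit).
    import dominate, numpy, scipy, torch, visdom
    from PIL import Image

    print('torch', torch.__version__)   # environment.yml pins 0.3.1
    print('numpy', numpy.__version__)   # environment.yml pins 1.14.1
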
diff --git a/make_dataset_aligned.py b/make_dataset_aligned.py
new file mode 100644
index 0000000..739c767
--- /dev/null
+++ b/make_dataset_aligned.py
@@ -0,0 +1,63 @@
+import os
+
+from PIL import Image
+
+
+def get_file_paths(folder):
+    image_file_paths = []
+    for root, dirs, filenames in os.walk(folder):
+        filenames = sorted(filenames)
+        for filename in filenames:
+            input_path = os.path.abspath(root)
+            file_path = os.path.join(input_path, filename)
+            if filename.endswith('.png') or filename.endswith('.jpg'):
+                image_file_paths.append(file_path)
+
+        break  # prevent descending into subfolders
+    return image_file_paths
+
+
+def align_images(a_file_paths, b_file_paths, target_path):
+    if not os.path.exists(target_path):
+        os.makedirs(target_path)
+
+    for i in range(len(a_file_paths)):
+        img_a = Image.open(a_file_paths[i])
+        img_b = Image.open(b_file_paths[i])
+        assert(img_a.size == img_b.size)
+
+        aligned_image = Image.new("RGB", (img_a.size[0] * 2, img_a.size[1]))
+        aligned_image.paste(img_a, (0, 0))
+        aligned_image.paste(img_b, (img_a.size[0], 0))
+        aligned_image.save(os.path.join(target_path, '{:04d}.jpg'.format(i)))
+
+
+if __name__ == '__main__':
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--dataset-path',
+        dest='dataset_path',
+        help='Which folder to process (it should have subfolders testA, testB, trainA and trainB)'
+    )
+    args = parser.parse_args()
+
+    dataset_folder = args.dataset_path
+    print(dataset_folder)
+
+    test_a_path = os.path.join(dataset_folder, 'testA')
+    test_b_path = os.path.join(dataset_folder, 'testB')
+    test_a_file_paths = get_file_paths(test_a_path)
+    test_b_file_paths = get_file_paths(test_b_path)
+    assert(len(test_a_file_paths) == len(test_b_file_paths))
+    test_path = os.path.join(dataset_folder, 'test')
+
+    train_a_path = os.path.join(dataset_folder, 'trainA')
+    train_b_path = os.path.join(dataset_folder, 'trainB')
+    train_a_file_paths = get_file_paths(train_a_path)
+    train_b_file_paths = get_file_paths(train_b_path)
+    assert(len(train_a_file_paths) == len(train_b_file_paths))
+    train_path = os.path.join(dataset_folder, 'train')
+
+    align_images(test_a_file_paths, test_b_file_paths, test_path)
+    align_images(train_a_file_paths, train_b_file_paths, train_path)
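
make_dataset_aligned.py converts a CycleGAN-style dataset (testA/testB/trainA/trainB) into pix2pix-style aligned images by pasting each A/B pair side by side into test/ and train/; it would be invoked as python make_dataset_aligned.py --dataset-path <folder>. A small sketch of the invariant the output should satisfy, assuming the script has already been run on an example folder (the path below is hypothetical):

    # Illustrative post-run check: every aligned image should be exactly
    # twice as wide as its source A image, and the counts should match.
    import os
    from PIL import Image

    dataset = 'datasets/facades'  # example path, not part of the commit
    for split in ('test', 'train'):
        a_dir = os.path.join(dataset, split + 'A')
        out_dir = os.path.join(dataset, split)
        a_files = sorted(f for f in os.listdir(a_dir) if f.endswith(('.png', '.jpg')))
        out_files = sorted(os.listdir(out_dir))
        assert len(a_files) == len(out_files)
        for a_name, out_name in zip(a_files, out_files):
            img_a = Image.open(os.path.join(a_dir, a_name))
            img_ab = Image.open(os.path.join(out_dir, out_name))
            assert img_ab.size == (img_a.size[0] * 2, img_a.size[1])
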
diff --git a/train.py b/train.py
index 61b596a..ee8cff1 100644
--- a/train.py
+++ b/train.py
@@ -4,54 +4,55 @@ from data.data_loader import CreateDataLoader
 from models.models import create_model
 from util.visualizer import Visualizer
 
-opt = TrainOptions().parse()
-data_loader = CreateDataLoader(opt)
-dataset = data_loader.load_data()
-dataset_size = len(data_loader)
-print('#training images = %d' % dataset_size)
+if __name__ == '__main__':
+    opt = TrainOptions().parse()
+    data_loader = CreateDataLoader(opt)
+    dataset = data_loader.load_data()
+    dataset_size = len(data_loader)
+    print('#training images = %d' % dataset_size)
 
-model = create_model(opt)
-visualizer = Visualizer(opt)
-total_steps = 0
+    model = create_model(opt)
+    visualizer = Visualizer(opt)
+    total_steps = 0
 
-for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
-    epoch_start_time = time.time()
-    iter_data_time = time.time()
-    epoch_iter = 0
+    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
+        epoch_start_time = time.time()
+        iter_data_time = time.time()
+        epoch_iter = 0
 
-    for i, data in enumerate(dataset):
-        iter_start_time = time.time()
-        if total_steps % opt.print_freq == 0:
-            t_data = iter_start_time - iter_data_time
-        visualizer.reset()
-        total_steps += opt.batchSize
-        epoch_iter += opt.batchSize
-        model.set_input(data)
-        model.optimize_parameters()
+        for i, data in enumerate(dataset):
+            iter_start_time = time.time()
+            if total_steps % opt.print_freq == 0:
+                t_data = iter_start_time - iter_data_time
+            visualizer.reset()
+            total_steps += opt.batchSize
+            epoch_iter += opt.batchSize
+            model.set_input(data)
+            model.optimize_parameters()
 
-        if total_steps % opt.display_freq == 0:
-            save_result = total_steps % opt.update_html_freq == 0
-            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
+            if total_steps % opt.display_freq == 0:
+                save_result = total_steps % opt.update_html_freq == 0
+                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
 
-        if total_steps % opt.print_freq == 0:
-            errors = model.get_current_errors()
-            t = (time.time() - iter_start_time) / opt.batchSize
-            visualizer.print_current_errors(epoch, epoch_iter, errors, t, t_data)
-            if opt.display_id > 0:
-                visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)
+            if total_steps % opt.print_freq == 0:
+                errors = model.get_current_errors()
+                t = (time.time() - iter_start_time) / opt.batchSize
+                visualizer.print_current_errors(epoch, epoch_iter, errors, t, t_data)
+                if opt.display_id > 0:
+                    visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)
 
-        if total_steps % opt.save_latest_freq == 0:
-            print('saving the latest model (epoch %d, total_steps %d)' %
+            if total_steps % opt.save_latest_freq == 0:
+                print('saving the latest model (epoch %d, total_steps %d)' %
+                      (epoch, total_steps))
+                model.save('latest')
+
+            iter_data_time = time.time()
+        if epoch % opt.save_epoch_freq == 0:
+            print('saving the model at the end of epoch %d, iters %d' %
                   (epoch, total_steps))
             model.save('latest')
-
-        iter_data_time = time.time()
-    if epoch % opt.save_epoch_freq == 0:
-        print('saving the model at the end of epoch %d, iters %d' %
-              (epoch, total_steps))
-        model.save('latest')
-        model.save(epoch)
+            model.save(epoch)
 
-    print('End of epoch %d / %d \t Time Taken: %d sec' %
-          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
-    model.update_learning_rate()
+        print('End of epoch %d / %d \t Time Taken: %d sec' %
+              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
+        model.update_learning_rate()
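
The train.py change is purely structural: the training loop is unchanged but now runs under an if __name__ == '__main__': guard. On Windows, which the peterjc123 channel above targets, multiprocessing uses spawn, so DataLoader worker processes re-import the main module; without the guard, module-level training code would execute again in every worker. A minimal, self-contained sketch of the pattern (illustrative, not repo code):

    # With num_workers > 0 under spawn, each worker re-imports this module,
    # so top-level training code must sit behind the __main__ guard.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def main():
        data = TensorDataset(torch.randn(8, 3), torch.randn(8, 1))
        loader = DataLoader(data, batch_size=2, num_workers=2)
        for x, y in loader:
            pass  # a real training step would go here

    if __name__ == '__main__':
        main()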