diff options
| -rw-r--r-- | README.md | 9 | ||||
| -rw-r--r-- | data/aligned_data_loader.py | 4 | ||||
| -rw-r--r-- | data/unaligned_data_loader.py | 32 | ||||
| -rw-r--r-- | scripts/test_pix2pix.sh | 2 | ||||
| -rw-r--r-- | scripts/train_cyclegan.sh | 2 | ||||
| -rw-r--r-- | scripts/train_pix2pix.sh | 2 | ||||
| -rw-r--r-- | util/visualizer.py | 4 |
7 files changed, 36 insertions, 19 deletions
@@ -44,7 +44,7 @@ In CVPR 2017. ## Getting Started ### Installation - Install PyTorch and dependencies from http://pytorch.org/ -- Install python libraries [dominate](https://github.com/Knio/dominate). +- Install python libraries [dominate](https://github.com/Knio/dominate) and [visdom](https://github.com/facebookresearch/visdom) (optional). - Clone this repo: ```bash git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix @@ -74,12 +74,12 @@ bash ./datasets/download_pix2pix_dataset.sh facades ``` - Train a model: ```bash -python train.py --dataroot ./datasets/facades --name facades_pix2pix --gpu_ids 0 --model pix2pix --align_data --which_direction BtoA +python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --align_data --which_direction BtoA ``` To view results as the model trains, check out the html file `./checkpoints/facades_pix2pix/web/index.html` - Test the model: ```bash -python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --phase val --align_data --which_direction BtoA +python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --align_data --which_direction BtoA ``` The test results will be saved to a html file here: `./results/facades_pix2pix/latest_val/index.html`. @@ -87,9 +87,10 @@ More example scripts can be found at `scripts` directory. ## Training/test Details - See `options/train_options.py` and `options/base_options.py` for training flags; see `optoins/test_options.py` and `options/base_options.py` for test flags. -- CPU/GPU: Set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. +- CPU/GPU (default `--gpu_ids 0`): Set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. - During training, you can visualize the result of current training. If you set `--display_id 0`, we will periodically save the training results to `[opt.checkpoints_dir]/[opt.name]/web/`. If you set `--display_id` > 0, the results will be shown on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should visdom installed. You need to invoke `python -m visdom.server` to start the server. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. + ### CycleGAN Datasets Download the CycleGAN datasets using the following script: ```bash diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py index 01dbf89..bea3531 100644 --- a/data/aligned_data_loader.py +++ b/data/aligned_data_loader.py @@ -17,9 +17,7 @@ class PairedData(object): return self def __next__(self): - # st() AB, AB_paths = next(self.data_loader_iter) - # st() w_total = AB.size(3) w = int(w_total / 2) h = AB.size(2) @@ -55,7 +53,7 @@ class AlignedDataLoader(BaseDataLoader): batch_size=self.opt.batchSize, shuffle=not self.opt.serial_batches, num_workers=int(self.opt.nThreads)) - + self.dataset = dataset self.paired_data = PairedData(data_loader, opt.fineSize) diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py index 95d9ac7..4f82dbe 100644 --- a/data/unaligned_data_loader.py +++ b/data/unaligned_data_loader.py @@ -3,12 +3,14 @@ import torchvision.transforms as transforms from data.base_data_loader import BaseDataLoader from data.image_folder import ImageFolder from builtins import object - +from pdb import set_trace as st class PairedData(object): def __init__(self, data_loader_A, data_loader_B): self.data_loader_A = data_loader_A self.data_loader_B = data_loader_B + self.stop_A = False + self.stop_B = False def __iter__(self): self.data_loader_A_iter = iter(self.data_loader_A) @@ -16,11 +18,29 @@ class PairedData(object): return self def __next__(self): - A, A_paths = next(self.data_loader_A_iter) - B, B_paths = next(self.data_loader_B_iter) - return {'A': A, 'A_paths': A_paths, - 'B': B, 'B_paths': B_paths} + A, A_paths = None, None + B, B_paths = None, None + try: + A, A_paths = next(self.data_loader_A_iter) + except StopIteration: + if A is None or A_paths is None: + self.stop_A = True + self.data_loader_A_iter = iter(self.data_loader_A) + A, A_paths = next(self.data_loader_A_iter) + try: + B, B_paths = next(self.data_loader_B_iter) + + except StopIteration: + if B is None or B_paths is None: + self.stop_B = True + self.data_loader_B_iter = iter(self.data_loader_B) + B, B_paths = next(self.data_loader_B_iter) + if self.stop_A and self.stop_B: + raise StopIteration() + else: + return {'A': A, 'A_paths': A_paths, + 'B': B, 'B_paths': B_paths} class UnalignedDataLoader(BaseDataLoader): def initialize(self, opt): @@ -60,4 +80,4 @@ class UnalignedDataLoader(BaseDataLoader): return self.paired_data def __len__(self): - return len(self.dataset_A) + return max(len(self.dataset_A), len(self.dataset_B)) diff --git a/scripts/test_pix2pix.sh b/scripts/test_pix2pix.sh index 7b056fe..f46b262 100644 --- a/scripts/test_pix2pix.sh +++ b/scripts/test_pix2pix.sh @@ -1 +1 @@ -python test.py --dataroot=./datasets/facades --name facades_pix2pix --model pix2pix --align_data +python test.py --dataroot=./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --align_data --which_direction BtoA diff --git a/scripts/train_cyclegan.sh b/scripts/train_cyclegan.sh index 03f7fd9..a3dd29b 100644 --- a/scripts/train_cyclegan.sh +++ b/scripts/train_cyclegan.sh @@ -1 +1 @@ -python train.py --dataroot=./datasets/horse2zebra --name horse2zebra_cyclegan --gpu_ids 0 --save_epoch_freq 5 +python train.py --dataroot=./datasets/maps --name maps_cyclegan --save_epoch_freq 5 diff --git a/scripts/train_pix2pix.sh b/scripts/train_pix2pix.sh index 682c6c6..b96517b 100644 --- a/scripts/train_pix2pix.sh +++ b/scripts/train_pix2pix.sh @@ -1 +1 @@ -python train.py --dataroot=./datasets/facades --name facades_pix2pix --which_model_netG resnet_9blocks --loadSize 286 --fineSize 256 --model pix2pix --align_data --which_direction BtoA --save_epoch_freq 25 +python train.py --dataroot=./datasets/facades --name facades_pix2pix --which_model_netG unet_256 --loadSize 286 --fineSize 256 --model pix2pix --align_data --which_direction BtoA --save_epoch_freq 25 diff --git a/util/visualizer.py b/util/visualizer.py index a839896..d5e7083 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -3,8 +3,6 @@ import os import ntpath import time from . import util -from . import html -from pdb import set_trace as st class Visualizer(): def __init__(self, opt): @@ -51,7 +49,7 @@ class Visualizer(): links.append(img_path) webpage.add_images(ims, txts, links, width=self.win_size) webpage.save() - # st() + # errors: dictionary of error labels and values def plot_current_errors(self, epoch, counter_ratio, opt, errors): if not hasattr(self, 'plot_data'): |
