summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md8
-rw-r--r--data/aligned_data_loader.py4
-rw-r--r--data/unaligned_data_loader.py32
-rw-r--r--scripts/test_pix2pix.sh2
-rw-r--r--scripts/train_cyclegan.sh2
-rw-r--r--scripts/train_pix2pix.sh2
-rw-r--r--util/visualizer.py4
7 files changed, 35 insertions, 19 deletions
diff --git a/README.md b/README.md
index 28631f3..cd79db9 100644
--- a/README.md
+++ b/README.md
@@ -44,7 +44,7 @@ In CVPR 2017.
## Getting Started
### Installation
- Install PyTorch and dependencies from http://pytorch.org/
-- Install python libraries [dominate](https://github.com/Knio/dominate).
+- Install python libraries [dominate](https://github.com/Knio/dominate) and [visdom](https://github.com/facebookresearch/visdom) (optional).
- Clone this repo:
```bash
git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
@@ -74,12 +74,12 @@ bash ./datasets/download_pix2pix_dataset.sh facades
```
- Train a model:
```bash
-python train.py --dataroot ./datasets/facades --name facades_pix2pix --gpu_ids 0 --model pix2pix --align_data --which_direction BtoA
+python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --align_data --which_direction BtoA
```
To view results as the model trains, check out the html file `./checkpoints/facades_pix2pix/web/index.html`
- Test the model:
```bash
-python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --phase val --align_data --which_direction BtoA
+python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --align_data --which_direction BtoA
```
The test results will be saved to a html file here: `./results/facades_pix2pix/latest_val/index.html`.
@@ -87,7 +87,7 @@ More example scripts can be found at `scripts` directory.
## Training/test Details
- See `options/train_options.py` and `options/base_options.py` for training flags; see `optoins/test_options.py` and `options/base_options.py` for test flags.
-- CPU/GPU: Set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode.
+- CPU/GPU (default `--gpu_ids 0`): Set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode.
- During training, you can visualize the result of current training. If you set `--display_id 0`, we will periodically save the training results to `[opt.checkpoints_dir]/[opt.name]/web/`. If you set `--display_id` > 0, the results will be shown on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should visdom installed. You need to invoke `python -m visdom.server` to start the server. The default server URL is `http://localhost:8097`
### CycleGAN Datasets
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
index 01dbf89..bea3531 100644
--- a/data/aligned_data_loader.py
+++ b/data/aligned_data_loader.py
@@ -17,9 +17,7 @@ class PairedData(object):
return self
def __next__(self):
- # st()
AB, AB_paths = next(self.data_loader_iter)
- # st()
w_total = AB.size(3)
w = int(w_total / 2)
h = AB.size(2)
@@ -55,7 +53,7 @@ class AlignedDataLoader(BaseDataLoader):
batch_size=self.opt.batchSize,
shuffle=not self.opt.serial_batches,
num_workers=int(self.opt.nThreads))
-
+
self.dataset = dataset
self.paired_data = PairedData(data_loader, opt.fineSize)
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
index 95d9ac7..4f82dbe 100644
--- a/data/unaligned_data_loader.py
+++ b/data/unaligned_data_loader.py
@@ -3,12 +3,14 @@ import torchvision.transforms as transforms
from data.base_data_loader import BaseDataLoader
from data.image_folder import ImageFolder
from builtins import object
-
+from pdb import set_trace as st
class PairedData(object):
def __init__(self, data_loader_A, data_loader_B):
self.data_loader_A = data_loader_A
self.data_loader_B = data_loader_B
+ self.stop_A = False
+ self.stop_B = False
def __iter__(self):
self.data_loader_A_iter = iter(self.data_loader_A)
@@ -16,11 +18,29 @@ class PairedData(object):
return self
def __next__(self):
- A, A_paths = next(self.data_loader_A_iter)
- B, B_paths = next(self.data_loader_B_iter)
- return {'A': A, 'A_paths': A_paths,
- 'B': B, 'B_paths': B_paths}
+ A, A_paths = None, None
+ B, B_paths = None, None
+ try:
+ A, A_paths = next(self.data_loader_A_iter)
+ except StopIteration:
+ if A is None or A_paths is None:
+ self.stop_A = True
+ self.data_loader_A_iter = iter(self.data_loader_A)
+ A, A_paths = next(self.data_loader_A_iter)
+ try:
+ B, B_paths = next(self.data_loader_B_iter)
+
+ except StopIteration:
+ if B is None or B_paths is None:
+ self.stop_B = True
+ self.data_loader_B_iter = iter(self.data_loader_B)
+ B, B_paths = next(self.data_loader_B_iter)
+ if self.stop_A and self.stop_B:
+ raise StopIteration()
+ else:
+ return {'A': A, 'A_paths': A_paths,
+ 'B': B, 'B_paths': B_paths}
class UnalignedDataLoader(BaseDataLoader):
def initialize(self, opt):
@@ -60,4 +80,4 @@ class UnalignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return len(self.dataset_A)
+ return max(len(self.dataset_A), len(self.dataset_B))
diff --git a/scripts/test_pix2pix.sh b/scripts/test_pix2pix.sh
index 7b056fe..f46b262 100644
--- a/scripts/test_pix2pix.sh
+++ b/scripts/test_pix2pix.sh
@@ -1 +1 @@
-python test.py --dataroot=./datasets/facades --name facades_pix2pix --model pix2pix --align_data
+python test.py --dataroot=./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --align_data --which_direction BtoA
diff --git a/scripts/train_cyclegan.sh b/scripts/train_cyclegan.sh
index 03f7fd9..a3dd29b 100644
--- a/scripts/train_cyclegan.sh
+++ b/scripts/train_cyclegan.sh
@@ -1 +1 @@
-python train.py --dataroot=./datasets/horse2zebra --name horse2zebra_cyclegan --gpu_ids 0 --save_epoch_freq 5
+python train.py --dataroot=./datasets/maps --name maps_cyclegan --save_epoch_freq 5
diff --git a/scripts/train_pix2pix.sh b/scripts/train_pix2pix.sh
index 682c6c6..b96517b 100644
--- a/scripts/train_pix2pix.sh
+++ b/scripts/train_pix2pix.sh
@@ -1 +1 @@
-python train.py --dataroot=./datasets/facades --name facades_pix2pix --which_model_netG resnet_9blocks --loadSize 286 --fineSize 256 --model pix2pix --align_data --which_direction BtoA --save_epoch_freq 25
+python train.py --dataroot=./datasets/facades --name facades_pix2pix --which_model_netG unet_256 --loadSize 286 --fineSize 256 --model pix2pix --align_data --which_direction BtoA --save_epoch_freq 25
diff --git a/util/visualizer.py b/util/visualizer.py
index a839896..d5e7083 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -3,8 +3,6 @@ import os
import ntpath
import time
from . import util
-from . import html
-from pdb import set_trace as st
class Visualizer():
def __init__(self, opt):
@@ -51,7 +49,7 @@ class Visualizer():
links.append(img_path)
webpage.add_images(ims, txts, links, width=self.win_size)
webpage.save()
- # st()
+
# errors: dictionary of error labels and values
def plot_current_errors(self, epoch, counter_ratio, opt, errors):
if not hasattr(self, 'plot_data'):