summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjunyanz <junyanz@berkeley.edu>2017-04-20 04:37:12 -0700
committerjunyanz <junyanz@berkeley.edu>2017-04-20 04:37:12 -0700
commit2d4b9bf9149126d20237dc357a8c23294273e91b (patch)
tree2df00ed9941fcfec08d478500ba473f16a0965ee
parentdee4a6844d464252f198e3a64ab7e919d5ded13a (diff)
make visdom default & add no_html & update scripts
-rw-r--r--README.md21
-rw-r--r--options/train_options.py8
-rw-r--r--scripts/test_cyclegan.sh1
-rw-r--r--scripts/test_pix2pix.sh2
-rw-r--r--scripts/train_cyclegan.sh2
-rw-r--r--scripts/train_pix2pix.sh2
-rw-r--r--util/visualizer.py11
7 files changed, 26 insertions, 21 deletions
diff --git a/README.md b/README.md
index 934e4bd..42ddcd3 100644
--- a/README.md
+++ b/README.md
@@ -60,12 +60,12 @@ cd pytorch-CycleGAN-and-pix2pix
```bash
bash ./datasets/download_cyclegan_dataset.sh maps
```
-- Train a model:
+- Train a model (`bash ./scripts/train_cyclegan.sh`):
```bash
python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
```
-To view results as the model trains, check out the html file `./checkpoints/maps_cyclegan/web/index.html`
-- Test the model:
+- To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097. To see more intermediate results, check out `./checkpoints/maps_cyclegan/web/index.html`
+- Test the model (`bash ./scripts/test_cyclegan.sh`):
```bash
python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --phase test
```
@@ -76,14 +76,14 @@ The test results will be saved to a html file here: `./results/maps_cyclegan/lat
```bash
bash ./datasets/download_pix2pix_dataset.sh facades
```
-- Train a model:
+- Train a model (`bash ./scripts/train_pix2pix.sh`):
```bash
-python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --align_data --which_direction BtoA
+python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --align_data --no_lsgan
```
-To view results as the model trains, check out the html file `./checkpoints/facades_pix2pix/web/index.html`
-- Test the model:
+- To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097. To see more intermediate results, check out `./checkpoints/facades_pix2pix/web/index.html`
+- Test the model (`bash ./scripts/test_pix2pix.sh`):
```bash
-python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --align_data --which_direction BtoA
+python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --align_data
```
The test results will be saved to a html file here: `./results/facades_pix2pix/latest_val/index.html`.
@@ -92,7 +92,7 @@ More example scripts can be found at `scripts` directory.
## Training/test Details
- See `options/train_options.py` and `options/base_options.py` for training flags; see `options/test_options.py` and `options/base_options.py` for test flags.
- CPU/GPU (default `--gpu_ids 0`): Set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode.
-- During training, the current results can be viewed using two methods. First, if you set `--display_id` > 0, the results will be shown on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should have visdom installed and a server running by the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom` set `--display_id 0`. Second, the intermediate results are saved to `[opt.checkpoints_dir]/[opt.name]/web/` as an HTML file. To avoid this, set `--no_html`.
+- During training, the current results can be viewed using two methods. First, if you set `--display_id` > 0, the results and loss plot will be shown on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should have visdom installed and a server running by the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom` set `--display_id 0`. Second, the intermediate results are saved to `[opt.checkpoints_dir]/[opt.name]/web/` as an HTML file. To avoid this, set `--no_html`.
### CycleGAN Datasets
@@ -139,7 +139,8 @@ This will combine each pair of images (A,B) into a single image file, ready for
## TODO
- add reflection and other padding layers.
-- add one-direction test model.
+- add one-direction test mode for CycleGAN.
+- add more preprocessing options.
- fully test Unet architecture.
- fully test instance normalization layer from [fast-neural-style project](https://github.com/darkstar112358/fast-neural-style).
- fully test CPU mode and multi-GPU mode.
diff --git a/options/train_options.py b/options/train_options.py
index e7d4b3a..b241863 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -7,8 +7,7 @@ class TrainOptions(BaseOptions):
self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
- self.parser.add_argument('--save_display_freq', type=int, default=2500, help='save the current display of results every save_display_freq_iterations')
- self.parser.add_argument('--continue_train', action='store_true', help='if continue training, load the latest model')
+ self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
@@ -16,9 +15,10 @@ class TrainOptions(BaseOptions):
self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
self.parser.add_argument('--ntrain', type=int, default=float("inf"), help='# of examples per epoch.')
self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
- self.parser.add_argument('--no_lsgan', action='store_true', help='if true, do *not* use least square GAN, if false, use vanilla GAN')
+ self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
- self.parser.add_argument('--preprocessing', type=str, default='resize_and_crop', help='resizing/cropping strategy')
+ self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
+ # NOT-IMPLEMENTED self.parser.add_argument('--preprocessing', type=str, default='resize_and_crop', help='resizing/cropping strategy')
self.isTrain = True
diff --git a/scripts/test_cyclegan.sh b/scripts/test_cyclegan.sh
new file mode 100644
index 0000000..714267e
--- /dev/null
+++ b/scripts/test_cyclegan.sh
@@ -0,0 +1 @@
+python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --phase test
diff --git a/scripts/test_pix2pix.sh b/scripts/test_pix2pix.sh
index f46b262..d5c2960 100644
--- a/scripts/test_pix2pix.sh
+++ b/scripts/test_pix2pix.sh
@@ -1 +1 @@
-python test.py --dataroot=./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --align_data --which_direction BtoA
+python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --align_data
diff --git a/scripts/train_cyclegan.sh b/scripts/train_cyclegan.sh
index a3dd29b..9da219f 100644
--- a/scripts/train_cyclegan.sh
+++ b/scripts/train_cyclegan.sh
@@ -1 +1 @@
-python train.py --dataroot=./datasets/maps --name maps_cyclegan --save_epoch_freq 5
+python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
diff --git a/scripts/train_pix2pix.sh b/scripts/train_pix2pix.sh
index b96517b..188050b 100644
--- a/scripts/train_pix2pix.sh
+++ b/scripts/train_pix2pix.sh
@@ -1 +1 @@
-python train.py --dataroot=./datasets/facades --name facades_pix2pix --which_model_netG unet_256 --loadSize 286 --fineSize 256 --model pix2pix --align_data --which_direction BtoA --save_epoch_freq 25
+python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --align_data --no_lsgan
diff --git a/util/visualizer.py b/util/visualizer.py
index d5e7083..bca74b2 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -3,17 +3,19 @@ import os
import ntpath
import time
from . import util
+from . import html
class Visualizer():
def __init__(self, opt):
# self.opt = opt
self.display_id = opt.display_id
+ self.use_html = not opt.no_html
self.name = opt.name
if self.display_id > 0:
import visdom
self.vis = visdom.Visdom()
- else:
- from . import html
+
+ if self.use_html:
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
self.img_dir = os.path.join(self.web_dir, 'images')
self.win_size = opt.display_winsize
@@ -30,7 +32,8 @@ class Visualizer():
self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
win=self.display_id + idx)
idx += 1
- else: # save images to a web directory
+
+ if self.use_html: # save images to a html file
for label, image_numpy in visuals.items():
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
util.save_image(image_numpy, img_path)
@@ -49,7 +52,7 @@ class Visualizer():
links.append(img_path)
webpage.add_images(ims, txts, links, width=self.win_size)
webpage.save()
-
+
# errors: dictionary of error labels and values
def plot_current_errors(self, epoch, counter_ratio, opt, errors):
if not hasattr(self, 'plot_data'):