author    junyanz <junyanz@berkeley.edu>  2017-04-18 03:38:47 -0700
committer junyanz <junyanz@berkeley.edu>  2017-04-18 03:38:47 -0700
commit    c99ce7c4e781712e0252c6127ad1a4e8021cc489 (patch)
tree      ba99dfd56a47036d9c1f18620abf4efc248839ab
first commit
-rw-r--r--  .gitignore                              40
-rw-r--r--  LICENSE                                 58
-rw-r--r--  README.md                              143
-rw-r--r--  data/__init__.py                         0
-rw-r--r--  data/aligned_data_loader.py             69
-rw-r--r--  data/base_data_loader.py                14
-rw-r--r--  data/data_loader.py                     12
-rw-r--r--  data/image_folder.py                    67
-rw-r--r--  data/unaligned_data_loader.py           63
-rw-r--r--  datasets/combine_A_and_B.py             49
-rw-r--r--  datasets/download_cyclegan_dataset.sh   14
-rw-r--r--  datasets/download_pix2pix_dataset.sh     8
-rw-r--r--  models/__init__.py                       0
-rw-r--r--  models/base_model.py                    56
-rw-r--r--  models/cycle_gan_model.py              222
-rw-r--r--  models/models.py                        15
-rw-r--r--  models/networks.py                     288
-rw-r--r--  models/pix2pix_model.py                147
-rw-r--r--  options/__init__.py                      0
-rw-r--r--  options/base_options.py                 71
-rw-r--r--  options/test_options.py                 12
-rw-r--r--  options/train_options.py                24
-rw-r--r--  scripts/test_pix2pix.sh                  1
-rw-r--r--  scripts/train_cyclegan.sh                1
-rw-r--r--  scripts/train_pix2pix.sh                 1
-rw-r--r--  test.py                                 34
-rw-r--r--  train.py                                52
-rw-r--r--  util/__init__.py                         0
-rw-r--r--  util/display.py                        115
-rw-r--r--  util/html.py                            64
-rw-r--r--  util/image_pool.py                      33
-rw-r--r--  util/png.py                             33
-rw-r--r--  util/util.py                            71
-rw-r--r--  util/visualizer.py                      86
34 files changed, 1863 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..4a26633
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,40 @@
+datasets/
+checkpoints/
+results/
+build/
+dist/
+*.png
+torch.egg-info/
+*/**/__pycache__
+torch/version.py
+torch/csrc/generic/TensorMethods.cpp
+torch/lib/*.so*
+torch/lib/*.dylib*
+torch/lib/*.h
+torch/lib/build
+torch/lib/tmp_install
+torch/lib/include
+torch/lib/torch_shm_manager
+torch/csrc/cudnn/cuDNN.cpp
+torch/csrc/nn/THNN.cwrap
+torch/csrc/nn/THNN.cpp
+torch/csrc/nn/THCUNN.cwrap
+torch/csrc/nn/THCUNN.cpp
+torch/csrc/nn/THNN_generic.cwrap
+torch/csrc/nn/THNN_generic.cpp
+torch/csrc/nn/THNN_generic.h
+docs/src/**/*
+test/data/legacy_modules.t7
+test/data/gpu_tensors.pt
+test/htmlcov
+test/.coverage
+*/*.pyc
+*/**/*.pyc
+*/**/**/*.pyc
+*/**/**/**/*.pyc
+*/**/**/**/**/*.pyc
+*/*.so*
+*/**/*.so*
+*/**/*.dylib*
+test/data/legacy_serialized.pt
+*~
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..d75f0ee
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,58 @@
+Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+--------------------------- LICENSE FOR pix2pix --------------------------------
+BSD License
+
+For pix2pix software
+Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+----------------------------- LICENSE FOR DCGAN --------------------------------
+BSD License
+
+For dcgan.torch software
+
+Copyright (c) 2015, Facebook, Inc. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..a02001d
--- /dev/null
+++ b/README.md
@@ -0,0 +1,143 @@
+<img src='imgs/horse2zebra.gif' align="right" width=384>
+
+<br><br><br>
+
+# CycleGAN and pix2pix in PyTorch
+
+This is our ongoing PyTorch implementation for both unpaired and paired image-to-image translation. Check out the original [CycleGAN Torch](https://github.com/junyanz/CycleGAN) and [pix2pix Torch](https://github.com/phillipi/pix2pix) if you would like to reproduce the exact results in the paper. The code was written by [Jun-Yan Zhu](https://github.com/junyanz) and [Taesung Park](https://github.com/taesung89).
+
+
+### CycleGAN: [[Project]](https://junyanz.github.io/CycleGAN/) [[Paper]](https://arxiv.org/pdf/1703.10593.pdf) [[Torch]](https://github.com/junyanz/CycleGAN)
+<img src="https://junyanz.github.io/CycleGAN/images/teaser_high_res.jpg" width="1000px"/>
+
+### Pix2pix: [[Project]](https://phillipi.github.io/pix2pix/) [[Paper]](https://arxiv.org/pdf/1611.07004v1.pdf) [[Torch]](https://github.com/phillipi/pix2pix)
+
+<img src="https://phillipi.github.io/pix2pix/images/teaser_v3.png" width="1000px"/>
+
+If you use this code for your research, please cite:
+
+Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
+[Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/)\*, [Taesung Park](https://taesung.me/)\*, [Phillip Isola](http://web.mit.edu/phillipi/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/)
+In arXiv, 2017. (* equal contributions)
+
+
+Image-to-Image Translation Using Conditional Adversarial Networks
+[Phillip Isola](http://web.mit.edu/phillipi/), [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/), [Tinghui Zhou](https://people.eecs.berkeley.edu/~tinghuiz/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/)
+In CVPR 2017.
+
+
+
+## Prerequisites
+- Linux or OSX.
+- Python 2 or 3.
+- CPU or NVIDIA GPU + CUDA cuDNN.
+
+## Getting Started
+### Installation
+- Install PyTorch and dependencies from http://pytorch.org/
+- Install the Python library [dominate](https://github.com/Knio/dominate).
+- Clone this repo:
+```bash
+git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
+cd pytorch-CycleGAN-and-pix2pix
+```
+
+### CycleGAN train/test
+- Download a CycleGAN dataset (e.g. maps):
+```bash
+bash ./datasets/download_cyclegan_dataset.sh maps
+```
+- Train a model:
+```bash
+python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
+```
+To view results as the model trains, check out the HTML file `./checkpoints/maps_cyclegan/web/index.html`.
+- Test the model:
+```bash
+python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --phase test
+```
+The test results will be saved to an HTML file here: `./results/maps_cyclegan/latest_test/index.html`.
+
+### pix2pix train/test
+- Download a pix2pix dataset (e.g. facades):
+```bash
+bash ./datasets/download_pix2pix_dataset.sh facades
+```
+- Train a model:
+```bash
+python train.py --dataroot ./datasets/facades --name facades_pix2pix --gpu_ids 0 --model pix2pix --align_data --which_direction BtoA
+```
+To view results as the model trains, check out the HTML file `./checkpoints/facades_pix2pix/web/index.html`.
+- Test the model:
+```bash
+python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --phase val --align_data --which_direction BtoA
+```
+The test results will be saved to an HTML file here: `./results/facades_pix2pix/latest_val/index.html`.
+
+More example scripts can be found in the `scripts` directory.
+
+
+## Training/test Details
+- See `options/train_options.py` and `options/base_options.py` for training flags; see `options/test_options.py` and `options/base_options.py` for test flags.
+- CPU/GPU: Set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode (see the example below).
+- If you set `--display_id 0`, the training results are saved to an HTML file at `./checkpoints/<name>/web/index.html`. If you set `--display_id` to a value greater than 0, results are sent to a browser-based graphics server; start it with `th -ldisplay.start 8000 0.0.0.0`. See [[szym/display]](https://github.com/szym/display) for more details.
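+
+For example (a purely illustrative command, reusing the maps example and the `--gpu_ids` flag documented above), CycleGAN can be trained in CPU mode with:
+```bash
+python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --gpu_ids -1
+```
+Note that CPU and multi-GPU modes are not yet fully tested (see the TODO list below).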
+
+### CycleGAN Datasets
+Download the CycleGAN datasets using the following script:
+```bash
+bash ./datasets/download_cyclegan_dataset.sh dataset_name
+```
+- `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/).
+- `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/).
+- `maps`: 1096 training images scraped from Google Maps.
+- `horse2zebra`: 939 horse images and 1177 zebra images downloaded from [ImageNet](http://www.image-net.org/) using keywords `wild horse` and `zebra`.
+- `apple2orange`: 996 apple images and 1020 orange images downloaded from [ImageNet](http://www.image-net.org/) using keywords `apple` and `navel orange`.
+- `summer2winter_yosemite`: 1273 summer Yosemite images and 854 winter Yosemite images downloaded using the Flickr API. See more details in our paper.
+- `monet2photo`, `vangogh2photo`, `ukiyoe2photo`, `cezanne2photo`: The art images were downloaded from [Wikiart](https://www.wikiart.org/). The real photos were downloaded from Flickr using a combination of the tags *landscape* and *landscapephotography*. The training set size of each class is Monet: 1074, Cezanne: 584, Van Gogh: 401, Ukiyo-e: 1433, Photographs: 6853.
+- `iphone2dslr_flower`: both classes of images were downloaded from Flickr. The training set size of each class is iPhone: 1813, DSLR: 3316. See more details in our paper.
+
+To train a model on your own dataset, you need to create a data folder with two subdirectories, `trainA` and `trainB`, that contain images from domain A and domain B, respectively. You can test your model on your training set by setting `--phase train` when running `test.py`. You can also create subdirectories `testA` and `testB` if you have additional test data.
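+
+A minimal sketch of the expected layout, assuming a hypothetical dataset named `mydata` placed under `./datasets/`:
+```bash
+mkdir -p ./datasets/mydata/trainA ./datasets/mydata/trainB   # training images for domain A / domain B
+mkdir -p ./datasets/mydata/testA ./datasets/mydata/testB     # optional test images
+# then train with:
+# python train.py --dataroot ./datasets/mydata --name mydata_cyclegan --model cycle_gan
+```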
+
+You should **not** expect our method to work on any combination of two random datasets (e.g. `cats<->keyboards`). From our experiments, we find it works better if two datasets share similar visual content. For example, `landscape painting<->landscape photographs` works much better than `portrait painting <-> landscape photographs`. `zebras<->horses` achieves compelling results while `cats<->dogs` completely fails. See the following section for more discussion.
+
+### pix2pix datasets
+Download the pix2pix datasets using the following script:
+```bash
+bash ./datasets/download_pix2pix_dataset.sh dataset_name
+```
+- `facades`: 400 images from [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/).
+- `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/).
+- `maps`: 1096 training images scraped from Google Maps.
+- `edges2shoes`: 50k training images from [UT Zappos50K dataset](http://vision.cs.utexas.edu/projects/finegrained/utzap50k/). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing.
+- `edges2handbags`: 137K Amazon Handbag images from [iGAN project](https://github.com/junyanz/iGAN). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing.
+
+We provide a Python script to generate pix2pix training data in the form of pairs of images {A,B}, where A and B are two different depictions of the same underlying scene. For example, these might be pairs {label map, photo} or {bw image, color image}. Then we can learn to translate A to B or B to A:
+
+Create a folder `/path/to/data` with subfolders `A` and `B`. `A` and `B` should each have their own subfolders `train`, `val`, `test`, etc. In `/path/to/data/A/train`, put training images in style A. In `/path/to/data/B/train`, put the corresponding images in style B. Repeat the same for the other data splits (`val`, `test`, etc.).
+
+Corresponding images in a pair {A,B} must be the same size and have the same filename, e.g. `/path/to/data/A/train/1.jpg` is considered to correspond to `/path/to/data/B/train/1.jpg`.
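+
+For concreteness, here is a minimal sketch of that layout (`/path/to/data` is a placeholder for your own location):
+```bash
+mkdir -p /path/to/data/A/train /path/to/data/A/val /path/to/data/A/test
+mkdir -p /path/to/data/B/train /path/to/data/B/val /path/to/data/B/test
+# e.g. /path/to/data/A/train/1.jpg must have a same-size counterpart at /path/to/data/B/train/1.jpg
+```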
+
+Once the data is formatted this way, call:
+```bash
+python datasets/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data
+```
+
+This will combine each pair of images (A,B) into a single image file, ready for training.
+
+## TODO
+- add Unet architecture
+- add one-direction test model
+- fully test instance normalization from [fast-neural-style project](https://github.com/darkstar112358/fast-neural-style)
+- fully test CPU mode and multi-GPU mode
+
+## Related Projects:
+[CycleGAN](https://github.com/junyanz/CycleGAN): Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
+[pix2pix](https://github.com/phillipi/pix2pix): Image-to-image translation using conditional adversarial nets
+[iGAN](https://github.com/junyanz/iGAN): Interactive Image Generation via Generative Adversarial Networks
+
+## Cat Paper Collection
+If you love cats and love reading cool graphics, vision, and learning papers, please check out the Cat Paper Collection:
+[[Github]](https://github.com/junyanz/CatPapers) [[Webpage]](http://people.eecs.berkeley.edu/~junyanz/cat/cat_papers.html)
+
+## Acknowledgments
+Code is inspired by [pytorch-DCGAN](https://github.com/pytorch/examples/tree/master/dcgan).
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/data/__init__.py
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
new file mode 100644
index 0000000..01dbf89
--- /dev/null
+++ b/data/aligned_data_loader.py
@@ -0,0 +1,69 @@
+import random
+import torch.utils.data
+import torchvision.transforms as transforms
+from data.base_data_loader import BaseDataLoader
+from data.image_folder import ImageFolder
+from pdb import set_trace as st
+from builtins import object
+
+class PairedData(object):
+ def __init__(self, data_loader, fineSize):
+ self.data_loader = data_loader
+ self.fineSize = fineSize
+ # st()
+
+ def __iter__(self):
+ self.data_loader_iter = iter(self.data_loader)
+ return self
+
+ def __next__(self):
+ # st()
+ AB, AB_paths = next(self.data_loader_iter)
+ # st()
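+ # AB is a single image with A on the left half and B on the right half;
+ # both halves are cropped at the same random offset so the pair stays aligned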
+ w_total = AB.size(3)
+ w = int(w_total / 2)
+ h = AB.size(2)
+
+ w_offset = random.randint(0, max(0, w - self.fineSize - 1))
+ h_offset = random.randint(0, max(0, h - self.fineSize - 1))
+
+ A = AB[:, :, h_offset:h_offset + self.fineSize,
+ w_offset:w_offset + self.fineSize]
+ B = AB[:, :, h_offset:h_offset + self.fineSize,
+ w + w_offset:w + w_offset + self.fineSize]
+
+ return {'A': A, 'A_paths': AB_paths, 'B': B, 'B_paths': AB_paths}
+
+
+class AlignedDataLoader(BaseDataLoader):
+ def initialize(self, opt):
+ BaseDataLoader.initialize(self, opt)
+ self.fineSize = opt.fineSize
+ transform = transforms.Compose([
+ # TODO: Scale
+ #transforms.Scale((opt.loadSize * 2, opt.loadSize)),
+ #transforms.CenterCrop(opt.fineSize),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))])
+
+ # dataset of side-by-side AB image pairs (A on the left half, B on the right half)
+ dataset = ImageFolder(root=opt.dataroot + '/' + opt.phase,
+ transform=transform, return_paths=True)
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=self.opt.batchSize,
+ shuffle=not self.opt.serial_batches,
+ num_workers=int(self.opt.nThreads))
+
+ self.dataset = dataset
+ self.paired_data = PairedData(data_loader, opt.fineSize)
+
+ def name(self):
+ return 'AlignedDataLoader'
+
+ def load_data(self):
+ return self.paired_data
+
+ def __len__(self):
+ return len(self.dataset)
diff --git a/data/base_data_loader.py b/data/base_data_loader.py
new file mode 100644
index 0000000..0e1deb5
--- /dev/null
+++ b/data/base_data_loader.py
@@ -0,0 +1,14 @@
+
+class BaseDataLoader():
+ def __init__(self):
+ pass
+
+ def initialize(self, opt):
+ self.opt = opt
+ pass
+
+ def load_data(self):
+ return None
+
+
+
diff --git a/data/data_loader.py b/data/data_loader.py
new file mode 100644
index 0000000..69035ea
--- /dev/null
+++ b/data/data_loader.py
@@ -0,0 +1,12 @@
+
+def CreateDataLoader(opt):
+ data_loader = None
+ if opt.align_data > 0:
+ from data.aligned_data_loader import AlignedDataLoader
+ data_loader = AlignedDataLoader()
+ else:
+ from data.unaligned_data_loader import UnalignedDataLoader
+ data_loader = UnalignedDataLoader()
+ print(data_loader.name())
+ data_loader.initialize(opt)
+ return data_loader
diff --git a/data/image_folder.py b/data/image_folder.py
new file mode 100644
index 0000000..44e15cb
--- /dev/null
+++ b/data/image_folder.py
@@ -0,0 +1,67 @@
+################################################################################
+# Code from
+# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
+# Modified the original code so that it also loads images from the current
+# directory as well as the subdirectories
+################################################################################
+
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir):
+ images = []
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+ for root, _, fnames in sorted(os.walk(dir)):
+ for fname in fnames:
+ if is_image_file(fname):
+ path = os.path.join(root, fname)
+ images.append(path)
+
+ return images
+
+
+def default_loader(path):
+ return Image.open(path).convert('RGB')
+
+
+class ImageFolder(data.Dataset):
+
+ def __init__(self, root, transform=None, return_paths=False,
+ loader=default_loader):
+ imgs = make_dataset(root)
+ if len(imgs) == 0:
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
+
+ self.root = root
+ self.imgs = imgs
+ self.transform = transform
+ self.return_paths = return_paths
+ self.loader = loader
+
+ def __getitem__(self, index):
+ path = self.imgs[index]
+ img = self.loader(path)
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.return_paths:
+ return img, path
+ else:
+ return img
+
+ def __len__(self):
+ return len(self.imgs)
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
new file mode 100644
index 0000000..95d9ac7
--- /dev/null
+++ b/data/unaligned_data_loader.py
@@ -0,0 +1,63 @@
+import torch.utils.data
+import torchvision.transforms as transforms
+from data.base_data_loader import BaseDataLoader
+from data.image_folder import ImageFolder
+from builtins import object
+
+
+class PairedData(object):
+ def __init__(self, data_loader_A, data_loader_B):
+ self.data_loader_A = data_loader_A
+ self.data_loader_B = data_loader_B
+
+ def __iter__(self):
+ self.data_loader_A_iter = iter(self.data_loader_A)
+ self.data_loader_B_iter = iter(self.data_loader_B)
+ return self
+
+ def __next__(self):
+ A, A_paths = next(self.data_loader_A_iter)
+ B, B_paths = next(self.data_loader_B_iter)
+ return {'A': A, 'A_paths': A_paths,
+ 'B': B, 'B_paths': B_paths}
+
+
+class UnalignedDataLoader(BaseDataLoader):
+ def initialize(self, opt):
+ BaseDataLoader.initialize(self, opt)
+ transform = transforms.Compose([
+ transforms.Scale(opt.loadSize),
+ transforms.CenterCrop(opt.fineSize),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))])
+
+ # Dataset A
+ dataset_A = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'A',
+ transform=transform, return_paths=True)
+ data_loader_A = torch.utils.data.DataLoader(
+ dataset_A,
+ batch_size=self.opt.batchSize,
+ shuffle=not self.opt.serial_batches,
+ num_workers=int(self.opt.nThreads))
+
+ # Dataset B
+ dataset_B = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'B',
+ transform=transform, return_paths=True)
+ data_loader_B = torch.utils.data.DataLoader(
+ dataset_B,
+ batch_size=self.opt.batchSize,
+ shuffle=not self.opt.serial_batches,
+ num_workers=int(self.opt.nThreads))
+ self.dataset_A = dataset_A
+ self.dataset_B = dataset_B
+ self.paired_data = PairedData(data_loader_A, data_loader_B)
+
+ def name(self):
+ return 'UnalignedDataLoader'
+
+ def load_data(self):
+ return self.paired_data
+
+ def __len__(self):
+ return len(self.dataset_A)
diff --git a/datasets/combine_A_and_B.py b/datasets/combine_A_and_B.py
new file mode 100644
index 0000000..4d1e2a2
--- /dev/null
+++ b/datasets/combine_A_and_B.py
@@ -0,0 +1,49 @@
+from pdb import set_trace as st
+import os
+import numpy as np
+import cv2
+import argparse
+
+parser = argparse.ArgumentParser('create image pairs')
+parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
+parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
+parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
+parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000)
+parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true')
+args = parser.parse_args()
+
+for arg in vars(args):
+ print('[%s] = ' % arg, getattr(args, arg))
+
+splits = os.listdir(args.fold_A)
+
+for sp in splits:
+ img_fold_A = os.path.join(args.fold_A, sp)
+ img_fold_B = os.path.join(args.fold_B, sp)
+ img_list = os.listdir(img_fold_A)
+ if args.use_AB:
+ img_list = [img_path for img_path in img_list if '_A.' in img_path]
+
+ num_imgs = min(args.num_imgs, len(img_list))
+ print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
+ img_fold_AB = os.path.join(args.fold_AB, sp)
+ if not os.path.isdir(img_fold_AB):
+ os.makedirs(img_fold_AB)
+ print('split = %s, number of images = %d' % (sp, num_imgs))
+ for n in range(num_imgs):
+ name_A = img_list[n]
+ path_A = os.path.join(img_fold_A, name_A)
+ if args.use_AB:
+ name_B = name_A.replace('_A.', '_B.')
+ else:
+ name_B = name_A
+ path_B = os.path.join(img_fold_B, name_B)
+ if os.path.isfile(path_A) and os.path.isfile(path_B):
+ name_AB = name_A
+ if args.use_AB:
+ name_AB = name_AB.replace('_A.', '.') # remove _A
+ path_AB = os.path.join(img_fold_AB, name_AB)
+ im_A = cv2.imread(path_A, cv2.IMREAD_COLOR)
+ im_B = cv2.imread(path_B, cv2.IMREAD_COLOR)
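+ # concatenate A (left) and B (right) along the width axis into a single AB image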
+ im_AB = np.concatenate([im_A, im_B], 1)
+ cv2.imwrite(path_AB, im_AB)
diff --git a/datasets/download_cyclegan_dataset.sh b/datasets/download_cyclegan_dataset.sh
new file mode 100644
index 0000000..1f0b163
--- /dev/null
+++ b/datasets/download_cyclegan_dataset.sh
@@ -0,0 +1,14 @@
+FILE=$1
+
+if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then
+ echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
+ exit 1
+fi
+
+URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
+ZIP_FILE=./datasets/$FILE.zip
+TARGET_DIR=./datasets/$FILE/
+wget -N $URL -O $ZIP_FILE
+mkdir $TARGET_DIR
+unzip $ZIP_FILE -d ./datasets/
+rm $ZIP_FILE
diff --git a/datasets/download_pix2pix_dataset.sh b/datasets/download_pix2pix_dataset.sh
new file mode 100644
index 0000000..2d28e4f
--- /dev/null
+++ b/datasets/download_pix2pix_dataset.sh
@@ -0,0 +1,8 @@
+FILE=$1
+URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz
+TAR_FILE=./datasets/$FILE.tar.gz
+TARGET_DIR=./datasets/$FILE/
+wget -N $URL -O $TAR_FILE
+mkdir $TARGET_DIR
+tar -zxvf $TAR_FILE -C ./datasets/
+rm $TAR_FILE
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/models/__init__.py
diff --git a/models/base_model.py b/models/base_model.py
new file mode 100644
index 0000000..0ea83d8
--- /dev/null
+++ b/models/base_model.py
@@ -0,0 +1,56 @@
+import os
+import torch
+from pdb import set_trace as st
+
+class BaseModel():
+ def name(self):
+ return 'BaseModel'
+
+ def initialize(self, opt):
+ self.opt = opt
+ self.gpu_ids = opt.gpu_ids
+ self.isTrain = opt.isTrain
+ self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
+
+ def set_input(self, input):
+ self.input = input
+
+ def forward(self):
+ pass
+
+ # used in test time, no backprop
+ def test(self):
+ pass
+
+ def get_image_paths(self):
+ pass
+
+ def optimize_parameters(self):
+ pass
+
+ def get_current_visuals(self):
+ return self.input
+
+ def get_current_errors(self):
+ return {}
+
+ def save(self, label):
+ pass
+
+ # helper saving function that can be used by subclasses
+ def save_network(self, network, network_label, epoch_label, use_gpu):
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
+ save_path = os.path.join(self.save_dir, save_filename)
+ torch.save(network.cpu().state_dict(), save_path)
+ if use_gpu and torch.cuda.is_available():
+ network.cuda()
+
+ # helper loading function that can be used by subclasses
+ def load_network(self, network, network_label, epoch_label):
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
+ save_path = os.path.join(self.save_dir, save_filename)
+ network.load_state_dict(torch.load(save_path))
+
+ def update_learning_rate(self):
+ pass
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
new file mode 100644
index 0000000..c3b5b72
--- /dev/null
+++ b/models/cycle_gan_model.py
@@ -0,0 +1,222 @@
+import numpy as np
+import torch
+import os
+from collections import OrderedDict
+from pdb import set_trace as st
+from torch.autograd import Variable
+import itertools
+import util.util as util
+from util.image_pool import ImagePool
+from .base_model import BaseModel
+from . import networks
+import sys
+
+class CycleGANModel(BaseModel):
+ def name(self):
+ return 'CycleGANModel'
+
+ def initialize(self, opt):
+ BaseModel.initialize(self, opt)
+
+ nb = opt.batchSize
+ size = opt.fineSize
+ self.input_A = self.Tensor(nb, opt.input_nc, size, size)
+ self.input_B = self.Tensor(nb, opt.output_nc, size, size)
+
+ # load/define networks
+ # The naming convention is different from the one used in the paper
+ # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
+
+ self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
+ opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
+ self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
+ opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
+
+ if self.isTrain:
+ use_sigmoid = opt.no_lsgan
+ self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
+ opt.which_model_netD,
+ opt.n_layers_D, use_sigmoid, self.gpu_ids)
+ self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
+ opt.which_model_netD,
+ opt.n_layers_D, use_sigmoid, self.gpu_ids)
+ if not self.isTrain or opt.continue_train:
+ which_epoch = opt.which_epoch
+ self.load_network(self.netG_A, 'G_A', which_epoch)
+ self.load_network(self.netG_B, 'G_B', which_epoch)
+ if self.isTrain:
+ self.load_network(self.netD_A, 'D_A', which_epoch)
+ self.load_network(self.netD_B, 'D_B', which_epoch)
+
+ if self.isTrain:
+ self.old_lr = opt.lr
+ self.fake_A_pool = ImagePool(opt.pool_size)
+ self.fake_B_pool = ImagePool(opt.pool_size)
+ # define loss functions
+ self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
+ self.criterionCycle = torch.nn.L1Loss()
+ self.criterionIdt = torch.nn.L1Loss()
+ # initialize optimizers
+ self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
+ lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
+ lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
+ lr=opt.lr, betas=(opt.beta1, 0.999))
+
+ print('---------- Networks initialized -------------')
+ networks.print_network(self.netG_A)
+ networks.print_network(self.netG_B)
+ networks.print_network(self.netD_A)
+ networks.print_network(self.netD_B)
+ print('-----------------------------------------------')
+
+ def set_input(self, input):
+ AtoB = self.opt.which_direction == 'AtoB'
+ input_A = input['A' if AtoB else 'B']
+ input_B = input['B' if AtoB else 'A']
+ self.input_A.resize_(input_A.size()).copy_(input_A)
+ self.input_B.resize_(input_B.size()).copy_(input_B)
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+ def forward(self):
+ self.real_A = Variable(self.input_A)
+ self.real_B = Variable(self.input_B)
+
+ def test(self):
+ self.real_A = Variable(self.input_A, volatile=True)
+ self.fake_B = self.netG_A.forward(self.real_A)
+ self.rec_A = self.netG_B.forward(self.fake_B)
+
+ self.real_B = Variable(self.input_B, volatile=True)
+ self.fake_A = self.netG_B.forward(self.real_B)
+ self.rec_B = self.netG_A.forward(self.fake_A)
+
+ #get image paths
+ def get_image_paths(self):
+ return self.image_paths
+
+ def backward_D_basic(self, netD, real, fake):
+ # Real
+ pred_real = netD.forward(real)
+ loss_D_real = self.criterionGAN(pred_real, True)
+ # Fake
+ pred_fake = netD.forward(fake.detach())
+ loss_D_fake = self.criterionGAN(pred_fake, False)
+ # Combined loss
+ loss_D = (loss_D_real + loss_D_fake) * 0.5
+ # backward
+ loss_D.backward()
+ return loss_D
+
+ def backward_D_A(self):
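+ # draw fake images from the image buffer that stores previously generated images (see --pool_size)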
+ fake_B = self.fake_B_pool.query(self.fake_B)
+ self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+
+ def backward_D_B(self):
+ fake_A = self.fake_A_pool.query(self.fake_A)
+ self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+
+ def backward_G(self):
+ lambda_idt = self.opt.identity
+ lambda_A = self.opt.lambda_A
+ lambda_B = self.opt.lambda_B
+ # Identity loss
+ if lambda_idt > 0:
+ # G_A should be identity if real_B is fed.
+ self.idt_A = self.netG_A.forward(self.real_B)
+ self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
+ # G_B should be identity if real_A is fed.
+ self.idt_B = self.netG_B.forward(self.real_A)
+ self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
+ else:
+ self.loss_idt_A = 0
+ self.loss_idt_B = 0
+
+ # GAN loss
+ # D_A(G_A(A))
+ self.fake_B = self.netG_A.forward(self.real_A)
+ pred_fake = self.netD_A.forward(self.fake_B)
+ self.loss_G_A = self.criterionGAN(pred_fake, True)
+ # D_B(G_B(B))
+ self.fake_A = self.netG_B.forward(self.real_B)
+ pred_fake = self.netD_B.forward(self.fake_A)
+ self.loss_G_B = self.criterionGAN(pred_fake, True)
+ # Forward cycle loss
+ self.rec_A = self.netG_B.forward(self.fake_B)
+ self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
+ # Backward cycle loss
+ self.rec_B = self.netG_A.forward(self.fake_A)
+ self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
+ # combined loss
+ self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
+ self.loss_G.backward()
+
+ def optimize_parameters(self):
+ # forward
+ self.forward()
+ # G_A and G_B
+ self.optimizer_G.zero_grad()
+ self.backward_G()
+ self.optimizer_G.step()
+ # D_A
+ self.optimizer_D_A.zero_grad()
+ self.backward_D_A()
+ self.optimizer_D_A.step()
+ # D_B
+ self.optimizer_D_B.zero_grad()
+ self.backward_D_B()
+ self.optimizer_D_B.step()
+
+
+ def get_current_errors(self):
+ D_A = self.loss_D_A.data[0]
+ G_A = self.loss_G_A.data[0]
+ Cyc_A = self.loss_cycle_A.data[0]
+ D_B = self.loss_D_B.data[0]
+ G_B = self.loss_G_B.data[0]
+ Cyc_B = self.loss_cycle_B.data[0]
+ if self.opt.identity > 0.0:
+ idt_A = self.loss_idt_A.data[0]
+ idt_B = self.loss_idt_B.data[0]
+ return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
+ ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
+ else:
+ return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
+ ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])
+
+ def get_current_visuals(self):
+ real_A = util.tensor2im(self.real_A.data)
+ fake_B = util.tensor2im(self.fake_B.data)
+ rec_A = util.tensor2im(self.rec_A.data)
+ real_B = util.tensor2im(self.real_B.data)
+ fake_A = util.tensor2im(self.fake_A.data)
+ rec_B = util.tensor2im(self.rec_B.data)
+ if self.opt.identity > 0.0:
+ idt_A = util.tensor2im(self.idt_A.data)
+ idt_B = util.tensor2im(self.idt_B.data)
+ return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
+ ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
+ else:
+ return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
+ ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
+
+ def save(self, label):
+ use_gpu = len(self.gpu_ids) > 0
+ self.save_network(self.netG_A, 'G_A', label, use_gpu)
+ self.save_network(self.netD_A, 'D_A', label, use_gpu)
+ self.save_network(self.netG_B, 'G_B', label, use_gpu)
+ self.save_network(self.netD_B, 'D_B', label, use_gpu)
+
+ def update_learning_rate(self):
+ lrd = self.opt.lr / self.opt.niter_decay
+ lr = self.old_lr - lrd
+ for param_group in self.optimizer_D_A.param_groups:
+ param_group['lr'] = lr
+ for param_group in self.optimizer_D_B.param_groups:
+ param_group['lr'] = lr
+ for param_group in self.optimizer_G.param_groups:
+ param_group['lr'] = lr
+
+ print('update learning rate: %f -> %f' % (self.old_lr, lr))
+ self.old_lr = lr
diff --git a/models/models.py b/models/models.py
new file mode 100644
index 0000000..7e790d0
--- /dev/null
+++ b/models/models.py
@@ -0,0 +1,15 @@
+
+def create_model(opt):
+ model = None
+ print(opt.model)
+ if opt.model == 'cycle_gan':
+ from .cycle_gan_model import CycleGANModel
+ assert(opt.align_data == False)
+ model = CycleGANModel()
+ if opt.model == 'pix2pix':
+ from .pix2pix_model import Pix2PixModel
+ assert(opt.align_data == True)
+ model = Pix2PixModel()
+ model.initialize(opt)
+ print("model [%s] was created" % (model.name()))
+ return model
diff --git a/models/networks.py b/models/networks.py
new file mode 100644
index 0000000..d41bd0e
--- /dev/null
+++ b/models/networks.py
@@ -0,0 +1,288 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+from pdb import set_trace as st
+
+###############################################################################
+# Functions
+###############################################################################
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ m.weight.data.normal_(0.0, 0.02)
+ elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1:
+ m.weight.data.normal_(1.0, 0.02)
+ m.bias.data.fill_(0)
+
+
+def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
+ netG = None
+ use_gpu = len(gpu_ids) > 0
+ if norm == 'batch':
+ norm_layer = nn.BatchNorm2d
+ elif norm == 'instance':
+ norm_layer = InstanceNormalization
+ else:
+ print('normalization layer [%s] is not found' % norm)
+
+ assert(torch.cuda.is_available() == use_gpu)
+ if which_model_netG == 'resnet_9blocks':
+ netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids)
+ elif which_model_netG == 'resnet_6blocks':
+ netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=6, gpu_ids=gpu_ids)
+ elif which_model_netG == 'unet':
+ netG = UnetGenerator(input_nc, output_nc, ngf, norm_layer, gpu_ids=gpu_ids)
+ else:
+ print('Generator model name [%s] is not recognized' % which_model_netG)
+ if use_gpu:
+ netG.cuda()
+ netG.apply(weights_init)
+ return netG
+
+
+def define_D(input_nc, ndf, which_model_netD,
+ n_layers_D=3, use_sigmoid=False, gpu_ids=[]):
+ netD = None
+ use_gpu = len(gpu_ids) > 0
+ assert(torch.cuda.is_available() == use_gpu)
+ if which_model_netD == 'basic':
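+ # the default 'basic' discriminator reuses the 'n_layers' PatchGAN with its default depth (n_layers_D=3)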
+ netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
+ elif which_model_netD == 'n_layers':
+ netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid, gpu_ids=gpu_ids)
+ else:
+ print('Discriminator model name [%s] is not recognized' %
+ which_model_netD)
+ if use_gpu:
+ netD.cuda()
+ netD.apply(weights_init)
+ return netD
+
+
+def print_network(net):
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ print(net)
+ print('Total number of parameters: %d' % num_params)
+
+
+##############################################################################
+# Classes
+##############################################################################
+
+
+# Defines the GAN loss used in LSGAN.
+# It is basically same as MSELoss, but it abstracts away the need to create
+# the target label tensor that has the same size as the input
+class GANLoss(nn.Module):
+ def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
+ tensor=torch.FloatTensor):
+ super(GANLoss, self).__init__()
+ self.real_label = target_real_label
+ self.fake_label = target_fake_label
+ self.real_label_var = None
+ self.fake_label_var = None
+ self.Tensor = tensor
+ if use_lsgan:
+ self.loss = nn.MSELoss()
+ else:
+ self.loss = nn.BCELoss()
+
+ def get_target_tensor(self, input, target_is_real):
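+ # create (and cache) a target label tensor of the same size as the prediction;
+ # it is rebuilt only when the prediction size changes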
+ target_tensor = None
+ if target_is_real:
+ create_label = ((self.real_label_var is None) or
+ (self.real_label_var.numel() != input.numel()))
+ if create_label:
+ real_tensor = self.Tensor(input.size()).fill_(self.real_label)
+ self.real_label_var = Variable(real_tensor, requires_grad=False)
+ target_tensor = self.real_label_var
+ else:
+ create_label = ((self.fake_label_var is None) or
+ (self.fake_label_var.numel() != input.numel()))
+ if create_label:
+ fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
+ self.fake_label_var = Variable(fake_tensor, requires_grad=False)
+ target_tensor = self.fake_label_var
+ return target_tensor
+
+ def __call__(self, input, target_is_real):
+ target_tensor = self.get_target_tensor(input, target_is_real)
+ return self.loss(input, target_tensor)
+
+
+# Defines the generator that consists of Resnet blocks between a few
+# downsampling/upsampling operations.
+# Code and idea originally from Justin Johnson's architecture.
+# https://github.com/jcjohnson/fast-neural-style/
+class ResnetGenerator(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, n_blocks=6, gpu_ids=[]):
+ assert(n_blocks >= 0)
+ super(ResnetGenerator, self).__init__()
+ self.input_nc = input_nc
+ self.output_nc = output_nc
+ self.ngf = ngf
+ self.gpu_ids = gpu_ids
+
+ model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3),
+ norm_layer(ngf),
+ nn.ReLU(True)]
+
+ n_downsampling = 2
+ for i in range(n_downsampling):
+ mult = 2**i
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
+ stride=2, padding=1),
+ norm_layer(ngf * mult * 2),
+ nn.ReLU(True)]
+
+ mult = 2**n_downsampling
+ for i in range(n_blocks):
+ model += [Resnet_block(ngf * mult, 'zero', norm_layer=norm_layer)]
+
+ for i in range(n_downsampling):
+ mult = 2**(n_downsampling - i)
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
+ kernel_size=3, stride=2,
+ padding=1, output_padding=1),
+ norm_layer(int(ngf * mult / 2)),
+ nn.ReLU(True)]
+
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=3)]
+ model += [nn.Tanh()]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input):
+ if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
+ else:
+ return self.model(input)
+
+
+
+# Define a resnet block
+class Resnet_block(nn.Module):
+ def __init__(self, dim, padding_type, norm_layer):
+ super(Resnet_block, self).__init__()
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer)
+
+ def build_conv_block(self, dim, padding_type, norm_layer):
+ conv_block = []
+ p = 0
+ # TODO: support padding types
+ assert(padding_type == 'zero')
+ p = 1
+
+ # TODO: InstanceNorm
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+ norm_layer(dim),
+ nn.ReLU(True)]
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+ norm_layer(dim)]
+
+ return nn.Sequential(*conv_block)
+
+ def forward(self, x):
+ out = x + self.conv_block(x)
+ return out
+
+
+# Defines the Unet geneator.
+class UnetGenerator(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, gpu_ids=[]):
+ super(UnetGenerator, self).__init__()
+ self.input_nc = input_nc
+ self.output_nc = output_nc
+ self.ngf = ngf
+ self.gpu_ids = gpu_ids
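+ # NOTE: the U-Net layers are not defined yet (see the README TODO "add Unet architecture");
+ # forward() will fail until self.model is built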
+
+ def forward(self, input):
+ if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
+ else:
+ return self.model(input)
+
+
+
+# Defines the PatchGAN discriminator with the specified arguments.
+class NLayerDiscriminator(nn.Module):
+ def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]):
+ super(NLayerDiscriminator, self).__init__()
+ self.gpu_ids = gpu_ids
+
+ kw = 4
+ sequence = [
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=2),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers):
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
+ kernel_size=kw, stride=2, padding=2),
+ # TODO: use InstanceNorm
+ nn.BatchNorm2d(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n_layers, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
+ kernel_size=1, stride=2, padding=2),
+ # TODO: useInstanceNorm
+ nn.BatchNorm2d(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=1)]
+ if use_sigmoid:
+     sequence += [nn.Sigmoid()]
+
+ self.model = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
+ else:
+ return self.model(input)
+
+# Instance Normalization layer from
+# https://github.com/darkstar112358/fast-neural-style
+
+class InstanceNormalization(torch.nn.Module):
+ """InstanceNormalization
+ Improves convergence of neural-style.
+ ref: https://arxiv.org/pdf/1607.08022.pdf
+ """
+
+ def __init__(self, dim, eps=1e-5):
+ super(InstanceNormalization, self).__init__()
+ self.weight = nn.Parameter(torch.FloatTensor(dim))
+ self.bias = nn.Parameter(torch.FloatTensor(dim))
+ self.eps = eps
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ self.weight.data.uniform_()
+ self.bias.data.zero_()
+
+ def forward(self, x):
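+ # normalize each sample's feature maps over the spatial dimensions (per channel),
+ # then apply the learned scale and shift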
+ n = x.size(2) * x.size(3)
+ t = x.view(x.size(0), x.size(1), n)
+ mean = torch.mean(t, 2).unsqueeze(2).expand_as(x)
+ # Calculate the biased var. torch.var returns unbiased var
+ var = torch.var(t, 2).unsqueeze(2).expand_as(x) * ((n - 1) / float(n))
+ scale_broadcast = self.weight.unsqueeze(1).unsqueeze(1).unsqueeze(0)
+ scale_broadcast = scale_broadcast.expand_as(x)
+ shift_broadcast = self.bias.unsqueeze(1).unsqueeze(1).unsqueeze(0)
+ shift_broadcast = shift_broadcast.expand_as(x)
+ out = (x - mean) / torch.sqrt(var + self.eps)
+ out = out * scale_broadcast + shift_broadcast
+ return out
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
new file mode 100644
index 0000000..1d89b29
--- /dev/null
+++ b/models/pix2pix_model.py
@@ -0,0 +1,147 @@
+import numpy as np
+import torch
+import os
+from collections import OrderedDict
+from pdb import set_trace as st
+from torch.autograd import Variable
+import util.util as util
+from util.image_pool import ImagePool
+from .base_model import BaseModel
+from . import networks
+
+
+class Pix2PixModel(BaseModel):
+ def name(self):
+ return 'Pix2PixModel'
+
+ def initialize(self, opt):
+ BaseModel.initialize(self, opt)
+ self.isTrain = opt.isTrain
+ # define tensors
+ self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
+ opt.fineSize, opt.fineSize)
+ self.input_B = self.Tensor(opt.batchSize, opt.output_nc,
+ opt.fineSize, opt.fineSize)
+
+ # load/define networks
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
+ opt.which_model_netG, opt.norm, self.gpu_ids)
+ if self.isTrain:
+ use_sigmoid = opt.no_lsgan
+ self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
+ opt.which_model_netD,
+ opt.n_layers_D, use_sigmoid, self.gpu_ids)
+ if not self.isTrain or opt.continue_train:
+ self.load_network(self.netG, 'G', opt.which_epoch)
+ if self.isTrain:
+ self.load_network(self.netD, 'D', opt.which_epoch)
+
+ if self.isTrain:
+ self.fake_AB_pool = ImagePool(opt.pool_size)
+ self.old_lr = opt.lr
+ # define loss functions
+ self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
+ self.criterionL1 = torch.nn.L1Loss()
+
+ # initialize optimizers
+ self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
+ lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
+ lr=opt.lr, betas=(opt.beta1, 0.999))
+
+ print('---------- Networks initialized -------------')
+ networks.print_network(self.netG)
+ networks.print_network(self.netD)
+ print('-----------------------------------------------')
+
+ def set_input(self, input):
+ AtoB = self.opt.which_direction == 'AtoB'
+ input_A = input['A' if AtoB else 'B']
+ input_B = input['B' if AtoB else 'A']
+ self.input_A.resize_(input_A.size()).copy_(input_A)
+ self.input_B.resize_(input_B.size()).copy_(input_B)
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+ def forward(self):
+ self.real_A = Variable(self.input_A)
+ self.fake_B = self.netG.forward(self.real_A)
+ self.real_B = Variable(self.input_B)
+
+ # no backprop gradients
+ def test(self):
+ self.real_A = Variable(self.input_A, volatile=True)
+ self.fake_B = self.netG.forward(self.real_A)
+ self.real_B = Variable(self.input_B, volatile=True)
+
+ #get image paths
+ def get_image_paths(self):
+ return self.image_paths
+
+ def backward_D(self):
+ # Fake
+ # stop backprop to the generator by detaching fake_B
+ fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
+ self.pred_fake = self.netD.forward(fake_AB.detach())
+ self.loss_D_fake = self.criterionGAN(self.pred_fake, False)
+
+ # Real
+ real_AB = torch.cat((self.real_A, self.real_B), 1)#.detach()
+ self.pred_real = self.netD.forward(real_AB)
+ self.loss_D_real = self.criterionGAN(self.pred_real, True)
+
+ # Combined loss
+ self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
+
+ self.loss_D.backward()
+
+ def backward_G(self):
+ # First, G(A) should fake the discriminator
+ fake_AB = torch.cat((self.real_A, self.fake_B), 1)
+ pred_fake = self.netD.forward(fake_AB)
+ self.loss_G_GAN = self.criterionGAN(pred_fake, True)
+
+ # Second, G(A) = B
+ self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A
+
+ self.loss_G = self.loss_G_GAN + self.loss_G_L1
+
+ self.loss_G.backward()
+
+ def optimize_parameters(self):
+ self.forward()
+
+ self.optimizer_D.zero_grad()
+ self.backward_D()
+ self.optimizer_D.step()
+
+ self.optimizer_G.zero_grad()
+ self.backward_G()
+ self.optimizer_G.step()
+
+ def get_current_errors(self):
+ return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]),
+ ('G_L1', self.loss_G_L1.data[0]),
+ ('D_real', self.loss_D_real.data[0]),
+ ('D_fake', self.loss_D_fake.data[0])
+ ])
+
+ def get_current_visuals(self):
+ real_A = util.tensor2im(self.real_A.data)
+ fake_B = util.tensor2im(self.fake_B.data)
+ real_B = util.tensor2im(self.real_B.data)
+ return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)])
+
+ def save(self, label):
+ use_gpu = len(self.gpu_ids) > 0
+ self.save_network(self.netG, 'G', label, use_gpu)
+ self.save_network(self.netD, 'D', label, use_gpu)
+
+ def update_learning_rate(self):
+ lrd = self.opt.lr / self.opt.niter_decay
+ lr = self.old_lr - lrd
+ for param_group in self.optimizer_D.param_groups:
+ param_group['lr'] = lr
+ for param_group in self.optimizer_G.param_groups:
+ param_group['lr'] = lr
+ print('update learning rate: %f -> %f' % (self.old_lr, lr))
+ self.old_lr = lr
diff --git a/options/__init__.py b/options/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/options/__init__.py
diff --git a/options/base_options.py b/options/base_options.py
new file mode 100644
index 0000000..c569306
--- /dev/null
+++ b/options/base_options.py
@@ -0,0 +1,71 @@
+import argparse
+import os
+from util import util
+from pdb import set_trace as st
+class BaseOptions():
+ def __init__(self):
+ self.parser = argparse.ArgumentParser()
+ self.initialized = False
+
+ def initialize(self):
+ self.parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
+ self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
+ self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')
+ self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
+ self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
+ self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
+ self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
+ self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
+ self.parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD')
+ self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks', help='selects model to use for netG')
+ self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
+ self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 or 0,1,2 or 0,2; use -1 for CPU mode')
+ self.parser.add_argument('--flip', action='store_true', help='if specified, flip the images for data augmentation')
+ self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
+ self.parser.add_argument('--align_data', action='store_true',
+ help='if True, the datasets are loaded from "test" and "train" directories and the data pairs are aligned')
+ self.parser.add_argument('--model', type=str, default='cycle_gan',
+ help='chooses which model to use. cycle_gan, one_direction_test, pix2pix, ...')
+ self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
+ self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
+ self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
+ self.parser.add_argument('--norm', type=str, default='batch', help='batch normalization or instance normalization')
+ self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+ self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
+ self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
+ self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity to a value other than 0 has the effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, set --identity 0.1')
+
+ self.initialized = True
+
+ def parse(self):
+ if not self.initialized:
+ self.initialize()
+ self.opt = self.parser.parse_args()
+ self.opt.isTrain = self.isTrain # train or test
+
+ str_ids = self.opt.gpu_ids.split(',')
+ self.opt.gpu_ids = []
+ for str_id in str_ids:
+ id = int(str_id)
+ if id >= 0:
+ self.opt.gpu_ids.append(id)
+
+ args = dict((name, getattr(self.opt, name)) for name in dir(self.opt)
+ if not name.startswith('_'))
+
+
+ print('------------ Options -------------')
+ for k, v in sorted(args.items()):
+ print('%s: %s' % (str(k), str(v)))
+ print('-------------- End ----------------')
+
+ # save to the disk
+ expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
+ util.mkdirs(expr_dir)
+ file_name = os.path.join(expr_dir, 'opt.txt')
+ with open(file_name, 'wt') as opt_file:
+ opt_file.write('------------ Options -------------\n')
+ for k, v in sorted(args.items()):
+ opt_file.write('%s: %s\n' % (str(k), str(v)))
+ opt_file.write('-------------- End ----------------\n')
+ return self.opt
diff --git a/options/test_options.py b/options/test_options.py
new file mode 100644
index 0000000..c4ecff6
--- /dev/null
+++ b/options/test_options.py
@@ -0,0 +1,12 @@
+from .base_options import BaseOptions
+
+class TestOptions(BaseOptions):
+ def initialize(self):
+ BaseOptions.initialize(self)
+ self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
+ self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
+ self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
+ self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+ self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+ self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
+ self.isTrain = False
diff --git a/options/train_options.py b/options/train_options.py
new file mode 100644
index 0000000..e7d4b3a
--- /dev/null
+++ b/options/train_options.py
@@ -0,0 +1,24 @@
+from .base_options import BaseOptions
+
+class TrainOptions(BaseOptions):
+ def initialize(self):
+ BaseOptions.initialize(self)
+ self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
+ self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
+ self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
+ self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
+ self.parser.add_argument('--save_display_freq', type=int, default=2500, help='save the current display of results every save_display_freq iterations')
+ self.parser.add_argument('--continue_train', action='store_true', help='if continue training, load the latest model')
+ self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
+ self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+        self.parser.add_argument('--niter', type=int, default=100, help='# of epochs at the initial learning rate')
+        self.parser.add_argument('--niter_decay', type=int, default=100, help='# of epochs over which to linearly decay the learning rate to zero')
+ self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
+ self.parser.add_argument('--ntrain', type=int, default=float("inf"), help='# of examples per epoch.')
+ self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
+        self.parser.add_argument('--no_lsgan', action='store_true', help='if set, use the vanilla GAN loss instead of the least-squares GAN (LSGAN) loss')
+ self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
+ self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
+ self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
+ self.parser.add_argument('--preprocessing', type=str, default='resize_and_crop', help='resizing/cropping strategy')
+ self.isTrain = True
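+
+# Hypothetical invocation overriding some of the defaults above (dataset paths are placeholders):
+#   python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix \
+#       --lr 0.0002 --beta1 0.5 --niter 100 --niter_decay 100 --save_epoch_freq 10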
diff --git a/scripts/test_pix2pix.sh b/scripts/test_pix2pix.sh
new file mode 100644
index 0000000..7b056fe
--- /dev/null
+++ b/scripts/test_pix2pix.sh
@@ -0,0 +1 @@
+python test.py --dataroot=./datasets/facades --name facades_pix2pix --model pix2pix --align_data
diff --git a/scripts/train_cyclegan.sh b/scripts/train_cyclegan.sh
new file mode 100644
index 0000000..03f7fd9
--- /dev/null
+++ b/scripts/train_cyclegan.sh
@@ -0,0 +1 @@
+python train.py --dataroot=./datasets/horse2zebra --name horse2zebra_cyclegan --gpu_ids 0 --save_epoch_freq 5
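+# To resume an interrupted run from the latest checkpoint (hypothetical example):
+# python train.py --dataroot=./datasets/horse2zebra --name horse2zebra_cyclegan --gpu_ids 0 --continue_train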
diff --git a/scripts/train_pix2pix.sh b/scripts/train_pix2pix.sh
new file mode 100644
index 0000000..682c6c6
--- /dev/null
+++ b/scripts/train_pix2pix.sh
@@ -0,0 +1 @@
+python train.py --dataroot=./datasets/facades --name facades_pix2pix --which_model_netG resnet_9blocks --loadSize 286 --fineSize 256 --model pix2pix --align_data --which_direction BtoA --save_epoch_freq 25
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..22d092c
--- /dev/null
+++ b/test.py
@@ -0,0 +1,34 @@
+import time
+import os
+from options.test_options import TestOptions
+opt = TestOptions().parse() # set CUDA_VISIBLE_DEVICES before import torch
+
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+from util.visualizer import Visualizer
+from pdb import set_trace as st
+from util import html
+
+opt.nThreads = 1 # test code only supports nThreads=1
+opt.batchSize = 1 #test code only supports batchSize=1
+opt.serial_batches = True # no shuffle
+
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+model = create_model(opt)
+visualizer = Visualizer(opt)
+# create website
+web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
+webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
+# test
+for i, data in enumerate(dataset):
+ if i >= opt.how_many:
+ break
+ model.set_input(data)
+ model.test()
+ visuals = model.get_current_visuals()
+ img_path = model.get_image_paths()
+ print('process image... %s' % img_path)
+ visualizer.save_images(webpage, visuals, img_path)
+
+webpage.save()
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..e85042f
--- /dev/null
+++ b/train.py
@@ -0,0 +1,52 @@
+import time
+from options.train_options import TrainOptions
+opt = TrainOptions().parse() # set CUDA_VISIBLE_DEVICES before import torch
+
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+from util.visualizer import Visualizer
+
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+num_train = len(data_loader)
+print('#training images = %d' % num_train)
+
+model = create_model(opt)
+visualizer = Visualizer(opt)
+
+total_steps = 0
+
+for epoch in range(1, opt.niter + opt.niter_decay + 1):
+ epoch_start_time = time.time()
+ for i, data in enumerate(dataset):
+ iter_start_time = time.time()
+ total_steps += opt.batchSize
+ epoch_iter = total_steps % num_train
+ model.set_input(data)
+ model.optimize_parameters()
+
+ if total_steps % opt.display_freq == 0:
+ visualizer.display_current_results(model.get_current_visuals(), epoch)
+
+ if total_steps % opt.print_freq == 0:
+ errors = model.get_current_errors()
+ visualizer.print_current_errors(epoch, epoch_iter, errors, iter_start_time)
+ if opt.display_id > 0:
+ visualizer.plot_current_errors(epoch, epoch_iter, opt, errors)
+
+ if total_steps % opt.save_latest_freq == 0:
+ print('saving the latest model (epoch %d, total_steps %d)' %
+ (epoch, total_steps))
+ model.save('latest')
+
+ if epoch % opt.save_epoch_freq == 0:
+ print('saving the model at the end of epoch %d, iters %d' %
+ (epoch, total_steps))
+ model.save('latest')
+ model.save(epoch)
+
+ print('End of epoch %d / %d \t Time Taken: %d sec' %
+ (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
+
+ if epoch > opt.niter:
+ model.update_learning_rate()
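+
+# Scheduling notes (a summary of the loop above, not extra logic):
+# - training runs for opt.niter + opt.niter_decay epochs in total;
+# - the learning rate stays at opt.lr for the first opt.niter epochs, after which
+#   model.update_learning_rate() is called once per epoch to decay it linearly
+#   towards zero over the remaining opt.niter_decay epochs;
+# - with the defaults (niter=100, niter_decay=100) that is 200 epochs, with the
+#   decay starting after epoch 100.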
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/util/__init__.py
diff --git a/util/display.py b/util/display.py
new file mode 100644
index 0000000..1403483
--- /dev/null
+++ b/util/display.py
@@ -0,0 +1,115 @@
+###############################################################################
+## Copied from https://github.com/szym/display.git
+## The Python package installer for that project is still under development,
+## so the code was adapted here directly.
+###############################################################################
+
+import base64
+import json
+import uuid
+import numpy
+
+try:
+ from urllib.parse import urlparse, urlencode
+ from urllib.request import urlopen, Request
+ from urllib.error import HTTPError
+except ImportError:
+ from urlparse import urlparse
+ from urllib import urlencode
+ from urllib2 import urlopen, Request, HTTPError
+from . import png
+
+__all__ = ['URL', 'image', 'images', 'plot']
+
+URL = 'http://localhost:8000/events'
+
+
+def uid():
+ return 'pane_%s' % uuid.uuid4()
+
+
+def send(**command):
+ command = json.dumps(command)
+    # attaching a data payload makes this a POST under both Python 2 and 3
+    # (the 'method' keyword of Request exists only in Python 3)
+    req = Request(URL, data=command.encode('ascii'))
+    req.add_header('Content-Type', 'application/text')
+    try:
+        resp = urlopen(req)
+        return resp is not None
+    except HTTPError:
+        return False
+
+
+def pane(panetype, win, title, content):
+ win = win or uid()
+ send(command='pane', type=panetype, id=win, title=title, content=content)
+ return win
+
+
+def normalize(img, opts):
+ minval = opts.get('min')
+ if minval is None:
+ minval = numpy.amin(img)
+ maxval = opts.get('max')
+ if maxval is None:
+ maxval = numpy.amax(img)
+
+ return numpy.uint8((img - minval) * (255/(maxval - minval)))
+
+
+def to_rgb(img):
+ nchannels = img.shape[2] if img.ndim == 3 else 1
+ if nchannels == 3:
+ return img
+ if nchannels == 1:
+ return img[:, :, numpy.newaxis].repeat(3, axis=2)
+ raise ValueError('Image must be RGB or gray-scale')
+
+
+def image(img, **opts):
+    # a list of images is delegated to images(); check this before touching .ndim
+    if isinstance(img, list):
+        return images(img, **opts)
+
+    assert img.ndim == 2 or img.ndim == 3
+ # TODO: if img is a 3d tensor, then unstack it into a list of images
+
+ img = to_rgb(normalize(img, opts))
+ pngbytes = png.encode(img.tostring(), img.shape[1], img.shape[0])
+ imgdata = 'data:image/png;base64,' + base64.b64encode(pngbytes).decode('ascii')
+
+ return pane('image', opts.get('win'), opts.get('title'), content={
+ 'src': imgdata,
+ 'labels': opts.get('labels'),
+ 'width': opts.get('width'),
+ })
+
+
+def images(images, **opts):
+ # TODO: need to merge images into a single canvas
+ raise Exception('Not implemented')
+
+
+def plot(data, **opts):
+ """ Plot data as line chart.
+ Params:
+ data: either a 2-d numpy array or a list of lists.
+ win: pane id
+ labels: list of series names, first series is always the X-axis
+ see http://dygraphs.com/options.html for other supported options
+ """
+ dataset = {}
+ if type(data).__module__ == numpy.__name__:
+ dataset = data.tolist()
+ else:
+ dataset = data
+
+ # clone opts into options
+ options = dict(opts)
+ options['file'] = dataset
+ if options.get('labels'):
+ options['xlabel'] = options['labels'][0]
+
+ # Don't pass our options to dygraphs.
+ options.pop('win', None)
+
+ return pane('plot', opts.get('win'), opts.get('title'), content=options)
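+
+
+# Hypothetical usage, assuming a display server is listening at URL
+# (see https://github.com/szym/display for the server side):
+#   import numpy
+#   from util import display
+#   display.image(numpy.random.rand(64, 64), title='noise')              # single grayscale pane
+#   display.plot([[0, 1.0], [1, 0.5], [2, 0.3]],
+#                labels=['iter', 'loss'], title='training loss')         # line chart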
diff --git a/util/html.py b/util/html.py
new file mode 100644
index 0000000..c7956f1
--- /dev/null
+++ b/util/html.py
@@ -0,0 +1,64 @@
+import dominate
+from dominate.tags import *
+import os
+
+
+class HTML:
+    def __init__(self, web_dir, title, refresh=0):
+ self.title = title
+ self.web_dir = web_dir
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ if not os.path.exists(self.web_dir):
+ os.makedirs(self.web_dir)
+ if not os.path.exists(self.img_dir):
+ os.makedirs(self.img_dir)
+ # print(self.img_dir)
+
+ self.doc = dominate.document(title=title)
+        if refresh > 0:
+            with self.doc.head:
+                meta(http_equiv="refresh", content=str(refresh))
+
+ def get_image_dir(self):
+ return self.img_dir
+
+    def add_header(self, text):
+        with self.doc:
+            h3(text)
+
+ def add_table(self, border=1):
+ self.t = table(border=border, style="table-layout: fixed;")
+ self.doc.add(self.t)
+
+ def add_images(self, ims, txts, links, width=400):
+ self.add_table()
+ with self.t:
+ with tr():
+ for im, txt, link in zip(ims, txts, links):
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
+ with p():
+ with a(href=os.path.join('images', link)):
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
+ br()
+ p(txt)
+
+    def save(self):
+        html_file = os.path.join(self.web_dir, 'index.html')
+        with open(html_file, 'wt') as f:
+            f.write(self.doc.render())
+
+
+if __name__ == '__main__':
+ html = HTML('web/', 'test_html')
+ html.add_header('hello world')
+
+ ims = []
+ txts = []
+ links = []
+ for n in range(4):
+ ims.append('image_%d.png' % n)
+ txts.append('text_%d' % n)
+ links.append('image_%d.png' % n)
+ html.add_images(ims, txts, links)
+ html.save()
diff --git a/util/image_pool.py b/util/image_pool.py
new file mode 100644
index 0000000..b59e185
--- /dev/null
+++ b/util/image_pool.py
@@ -0,0 +1,33 @@
+import random
+import numpy as np
+import torch
+from pdb import set_trace as st
+from torch.autograd import Variable
+class ImagePool():
+ def __init__(self, pool_size):
+ self.pool_size = pool_size
+ if self.pool_size > 0:
+ self.num_imgs = 0
+ self.images = []
+
+ def query(self, images):
+ if self.pool_size == 0:
+ return images
+ return_images = []
+ for image in images.data:
+ image = torch.unsqueeze(image, 0)
+ if self.num_imgs < self.pool_size:
+ self.num_imgs = self.num_imgs + 1
+ self.images.append(image)
+ return_images.append(image)
+ else:
+ p = random.uniform(0, 1)
+ if p > 0.5:
+ random_id = random.randint(0, self.pool_size-1)
+ tmp = self.images[random_id].clone()
+ self.images[random_id] = image
+ return_images.append(tmp)
+ else:
+ return_images.append(image)
+ return_images = Variable(torch.cat(return_images, 0))
+ return return_images
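+
+# The pool keeps a history of previously generated images: new images fill the buffer
+# until it holds pool_size entries; after that each incoming image either replaces a
+# random stored one (and the stored one is returned instead) or is passed through,
+# each with probability 0.5. Minimal usage sketch (sizes are illustrative):
+#   pool = ImagePool(pool_size=50)
+#   fake = Variable(torch.randn(1, 3, 256, 256))   # e.g. a generated batch
+#   fake_for_D = pool.query(fake)                  # mix of current and past fakes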
diff --git a/util/png.py b/util/png.py
new file mode 100644
index 0000000..0936cf0
--- /dev/null
+++ b/util/png.py
@@ -0,0 +1,33 @@
+import struct
+import zlib
+
+def encode(buf, width, height):
+ """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """
+ assert (width * height * 3 == len(buf))
+ bpp = 3
+
+ def raw_data():
+ # reverse the vertical line order and add null bytes at the start
+ row_bytes = width * bpp
+ for row_start in range((height - 1) * width * bpp, -1, -row_bytes):
+ yield b'\x00'
+ yield buf[row_start:row_start + row_bytes]
+
+ def chunk(tag, data):
+ return [
+ struct.pack("!I", len(data)),
+ tag,
+ data,
+ struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag)))
+ ]
+
+ SIGNATURE = b'\x89PNG\r\n\x1a\n'
+ COLOR_TYPE_RGB = 2
+ COLOR_TYPE_RGBA = 6
+ bit_depth = 8
+ return b''.join(
+ [ SIGNATURE ] +
+ chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) +
+ chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) +
+ chunk(b'IEND', b'')
+ )
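+
+# Usage sketch: encode a 2x2 solid-red RGB image and write it to disk.
+#   raw = b'\xff\x00\x00' * 4            # 4 pixels, RGBRGB... byte string
+#   with open('red.png', 'wb') as f:
+#       f.write(encode(raw, 2, 2))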
diff --git a/util/util.py b/util/util.py
new file mode 100644
index 0000000..781239f
--- /dev/null
+++ b/util/util.py
@@ -0,0 +1,71 @@
+from __future__ import print_function
+import torch
+import numpy as np
+from PIL import Image
+import inspect, re
+import os
+import collections
+
+# Converts a Tensor into a Numpy array
+# |imtype|: the desired type of the converted numpy array
+def tensor2im(image_tensor, imtype=np.uint8):
+ image_numpy = image_tensor[0].cpu().float().numpy()
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+ return image_numpy.astype(imtype)
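+# e.g. assuming image_tensor holds values in [-1, 1] with shape (1, 3, H, W),
+# the result is an (H, W, 3) uint8 array where -1 -> 0, 0 -> 127, and 1 -> 255.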
+
+
+def diagnose_network(net, name='network'):
+ mean = 0.0
+ count = 0
+ for param in net.parameters():
+ if param.grad is not None:
+ mean += torch.mean(torch.abs(param.grad.data))
+ count += 1
+ if count > 0:
+ mean = mean / count
+ print(name)
+ print(mean)
+
+
+def save_image(image_numpy, image_path):
+ image_pil = Image.fromarray(image_numpy)
+ image_pil.save(image_path)
+
+def info(object, spacing=10, collapse=1):
+ """Print methods and doc strings.
+ Takes module, class, list, dictionary, or string."""
+ methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.Callable)]
+ processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
+ print( "\n".join(["%s %s" %
+ (method.ljust(spacing),
+ processFunc(str(getattr(object, method).__doc__)))
+ for method in methodList]) )
+
+def varname(p):
+ for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
+ m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
+ if m:
+ return m.group(1)
+
+def print_numpy(x, val=True, shp=False):
+ x = x.astype(np.float64)
+ if shp:
+ print('shape,', x.shape)
+ if val:
+ x = x.flatten()
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def mkdirs(paths):
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
diff --git a/util/visualizer.py b/util/visualizer.py
new file mode 100644
index 0000000..0b8578e
--- /dev/null
+++ b/util/visualizer.py
@@ -0,0 +1,86 @@
+import numpy as np
+import os
+import ntpath
+import time
+from . import util
+from . import html
+from pdb import set_trace as st
+
+class Visualizer():
+ def __init__(self, opt):
+ # self.opt = opt
+ self.display_id = opt.display_id
+ if self.display_id > 0:
+ from . import display
+ self.display = display
+ else:
+ from . import html
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ self.name = opt.name
+ self.win_size = opt.display_winsize
+ print('create web directory %s...' % self.web_dir)
+ util.mkdirs([self.web_dir, self.img_dir])
+
+
+ # |visuals|: dictionary of images to display or save
+ def display_current_results(self, visuals, epoch):
+ if self.display_id > 0: # show images in the browser
+ idx = 0
+            for label, image_numpy in visuals.items():
+ image_numpy = np.flipud(image_numpy)
+ self.display.image(image_numpy, title=label,
+ win=self.display_id + idx)
+ idx += 1
+ else: # save images to a web directory
+ for label, image_numpy in visuals.items():
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+ util.save_image(image_numpy, img_path)
+ # update website
+            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
+ for n in range(epoch, 0, -1):
+ webpage.add_header('epoch [%d]' % n)
+ ims = []
+ txts = []
+ links = []
+
+ for label, image_numpy in visuals.items():
+ img_path = 'epoch%.3d_%s.png' % (n, label)
+ ims.append(img_path)
+ txts.append(label)
+ links.append(img_path)
+ webpage.add_images(ims, txts, links, width=self.win_size)
+ webpage.save()
+ # st()
+ # errors: dictionary of error labels and values
+ def plot_current_errors(self, epoch, i, opt, errors):
+ pass
+
+ # errors: same format as |errors| of plotCurrentErrors
+ def print_current_errors(self, epoch, i, errors, start_time):
+ message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, time.time() - start_time)
+ for k, v in errors.items():
+ message += '%s: %.3f ' % (k, v)
+
+ print(message)
+
+ # save image to the disk
+ def save_images(self, webpage, visuals, image_path):
+ image_dir = webpage.get_image_dir()
+ short_path = ntpath.basename(image_path[0])
+ name = os.path.splitext(short_path)[0]
+
+ webpage.add_header(name)
+ ims = []
+ txts = []
+ links = []
+
+ for label, image_numpy in visuals.items():
+ image_name = '%s_%s.png' % (name, label)
+ save_path = os.path.join(image_dir, image_name)
+ util.save_image(image_numpy, save_path)
+
+ ims.append(image_name)
+ txts.append(label)
+ links.append(image_name)
+ webpage.add_images(ims, txts, links, width=self.win_size)
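+
+# Usage sketch (mirrors test.py; the model and webpage objects are assumed to exist):
+#   visualizer = Visualizer(opt)
+#   visuals = model.get_current_visuals()        # dict of label -> numpy image
+#   visualizer.save_images(webpage, visuals, model.get_image_paths())
+#   webpage.save()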