| author | tingchunw <tingchunw@nvidia.com> | 2017-12-04 16:52:46 -0800 |
|---|---|---|
| committer | tingchunw <tingchunw@nvidia.com> | 2017-12-04 16:52:46 -0800 |
| commit | 9054cf9b0c327a5077fd0793abe178f400da3315 | |
| tree | 3c69c07bdcba86c47d8442648fd69c0434e04136 | |
| parent | f9e9999541d67a908a169cc88407675133130e1f | |
first commit
85 files changed, 1982 insertions, 5 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100755
index 0000000..681efd0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,40 @@
+debug*
+checkpoints/
+results/
+build/
+dist/
+torch.egg-info/
+*/**/__pycache__
+torch/version.py
+torch/csrc/generic/TensorMethods.cpp
+torch/lib/*.so*
+torch/lib/*.dylib*
+torch/lib/*.h
+torch/lib/build
+torch/lib/tmp_install
+torch/lib/include
+torch/lib/torch_shm_manager
+torch/csrc/cudnn/cuDNN.cpp
+torch/csrc/nn/THNN.cwrap
+torch/csrc/nn/THNN.cpp
+torch/csrc/nn/THCUNN.cwrap
+torch/csrc/nn/THCUNN.cpp
+torch/csrc/nn/THNN_generic.cwrap
+torch/csrc/nn/THNN_generic.cpp
+torch/csrc/nn/THNN_generic.h
+docs/src/**/*
+test/data/legacy_modules.t7
+test/data/gpu_tensors.pt
+test/htmlcov
+test/.coverage
+*/*.pyc
+*/**/*.pyc
+*/**/**/*.pyc
+*/**/**/**/*.pyc
+*/**/**/**/**/*.pyc
+*/*.so*
+*/**/*.so*
+*/**/*.dylib*
+test/data/legacy_serialized.pt
+*.DS_Store
+*~
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100755
index 0000000..7406cd7
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,43 @@
+Copyright (C) 2017 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
+All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+
+Permission to use, copy, modify, and distribute this software and its documentation
+for any non-commercial purpose is hereby granted without fee, provided that the above
+copyright notice appear in all copies and that both that copyright notice and this
+permission notice appear in supporting documentation, and that the name of the author
+not be used in advertising or publicity pertaining to distribution of the software
+without specific, written prior permission.
+
+THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
+IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
+DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
+WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
+OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+
+--------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix ----------------
+Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
@@ -10,10 +10,6 @@ Pytorch implementation of our method for high-resolution (e.g. 2048x1024) photor
 <sup>1</sup>NVIDIA Corporation, <sup>2</sup>UC Berkeley
In arxiv, 2017.
-## Release notice
-The code is ready to publish but still under final approval process. It should be approved in a couple of days.<br>
-If you want to get notified once the code is released, please subscribe [here](https://tcwang0509.github.io/pix2pixHD/subscribe.html).
-
## Image-to-image translation at 2k/1k resolution
- Our label-to-streetview results
<p align='center'>
@@ -53,7 +49,69 @@ If you want to get notified once the code is released, please subscribe [here](h
 <img src='imgs/face_short.gif' width='490'/>
</p>
-### Citation
+## Prerequisites
+- Linux or macOS
+- Python 2 or 3
+- NVIDIA GPU (12 GB or 24 GB memory) + CUDA and cuDNN
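+
+To sanity-check that PyTorch can see the GPU before going further, a quick one-liner (using the standard `torch.cuda` API) is:
+```bash
+python -c "import torch; print(torch.cuda.is_available())"
+```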
+
+## Getting Started
+### Installation
+- Install PyTorch and dependencies from http://pytorch.org
+- Install the Python library [dominate](https://github.com/Knio/dominate):
+```bash
+pip install dominate
+```
+- Clone this repo:
+```bash
+git clone https://github.com/NVIDIA/pix2pixHD
+cd pix2pixHD
+```
+
+
+### Testing
+- A few example Cityscapes test images are included in the `datasets` folder.
+- Please download the pre-trained Cityscapes model from [here](https://drive.google.com/file/d/1h9SykUnuZul7J3Nbms2QGH1wa85nbN2-/view?usp=sharing) (Google Drive link) and put it under `./checkpoints/label2city_1024p/`.
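+One way to set this up is sketched below; the file name `latest_net_G.pth` is an assumption, following the `'%s_net_%s.pth'` naming convention in `models/base_model.py`:
+```bash
+mkdir -p ./checkpoints/label2city_1024p
+mv /path/to/downloaded/latest_net_G.pth ./checkpoints/label2city_1024p/
+```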
+- Test the model (`bash ./scripts/test_1024p.sh`):
+```bash
+#!./scripts/test_1024p.sh
+python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none
+```
+The test results will be saved to an HTML file: `./results/label2city_1024p/test_latest/index.html`.
+
+More example scripts can be found in the `scripts` directory.
+
+
+### Dataset
+- We use the Cityscapes dataset. To train a model on the full dataset, please download it from the [official website](https://www.cityscapes-dataset.com/) (registration required).
+After downloading, please put it under the `datasets` folder, following the same layout as the provided example images.
+
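+For reference, the example images in this commit are organized as below; `data/aligned_dataset.py` derives the folder names from `opt.phase` plus the suffixes `_label`, `_img`, and `_inst`:
+```
+datasets/cityscapes/
+├── train_img/     # *_leftImg8bit.png real images
+├── train_label/   # *_gtFine_labelIds.png label maps
+├── train_inst/    # *_gtFine_instanceIds.png instance maps
+├── test_label/
+└── test_inst/
+```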
+
+### Training
+- Train a model at 1024 x 512 resolution (`bash ./scripts/train_512p.sh`):
+```bash
+#!./scripts/train_512p.sh
+python train.py --name label2city_512p
+```
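+- If training is interrupted, it can be resumed with the `--continue_train` flag (checked in `models/pix2pixHD_model.py` when deciding whether to load saved networks):
+```bash
+python train.py --name label2city_512p --continue_train
+```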
+- To view training results, please check out the intermediate results in `./checkpoints/label2city_512p/web/index.html`.
+If you have TensorFlow installed, you can also view TensorBoard logs in `./checkpoints/label2city_512p/logs` by adding `--tf_log` to the training scripts.
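+For example, the same 512p training run with TensorBoard logging enabled:
+```bash
+python train.py --name label2city_512p --tf_log
+```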
+
+### Multi-GPU training
+- Train a model using multiple GPUs (`bash ./scripts/train_512p_multigpu.sh`):
+```bash
+#!./scripts/train_512p_multigpu.sh
+python train.py --name label2city_512p --batchSize 8 --gpu_ids 0,1,2,3,4,5,6,7
+```
+Note: multi-GPU training has not been tested; we trained our models using a single GPU only, so please use it at your own discretion. Since the model is wrapped in `torch.nn.DataParallel` (see `models/models.py`), the batch specified by `--batchSize` is split across the listed GPUs.
+
+### Training at full resolution
+- Training at full resolution (2048 x 1024) requires a GPU with 24 GB of memory (`bash ./scripts/train_1024p_24G.sh`).
+If only GPUs with 12 GB of memory are available, please use the 12G script (`bash ./scripts/train_1024p_12G.sh`), which crops the images during training. Performance is not guaranteed with this script.
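+As a rough sketch, the 12G variant combines the 1024p generator settings from the test command above with cropping; the exact flags live in `./scripts/train_1024p_12G.sh`, and the experiment name and crop size below are assumptions:
+```bash
+python train.py --name label2city_1024p_12G --netG local --ngf 32 --resize_or_crop crop --fineSize 1024
+```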
+
+## More Training/Test Details
+- Flags: see `options/train_options.py` and `options/base_options.py` for all the training flags; see `options/test_options.py` and `options/base_options.py` for all the test flags.
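+Assuming the options classes wrap Python's standard `argparse` (as the `TrainOptions().parse()` call in `encode_features.py` suggests), the full flag list can also be printed from the command line:
+```bash
+python train.py --help
+python test.py --help
+```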
+
+
+## Citation
If you find this useful for your research, please use the following.
@@ -65,3 +123,6 @@ If you find this useful for your research, please use the following.
  year={2017}
}
```
+
+## Acknowledgments
+This code borrows heavily from [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
diff --git a/_config.yml b/_config.yml
new file mode 100755
index 0000000..2f7efbe
--- /dev/null
+++ b/_config.yml
@@ -0,0 +1 @@
+theme: jekyll-theme-minimal
\ No newline at end of file
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100755
index 0000000..e69de29
--- /dev/null
+++ b/data/__init__.py
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
new file mode 100755
index 0000000..50390f3
--- /dev/null
+++ b/data/aligned_dataset.py
@@ -0,0 +1,76 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import os.path
+import random
+import torchvision.transforms as transforms
+import torch
+from data.base_dataset import BaseDataset, get_params, get_transform, normalize
+from data.image_folder import make_dataset
+from PIL import Image
+import numpy as np
+
+class AlignedDataset(BaseDataset):
+    def initialize(self, opt):
+        self.opt = opt
+        self.root = opt.dataroot
+
+        ### label maps
+        self.dir_label = os.path.join(opt.dataroot, opt.phase + '_label')
+        self.label_paths = sorted(make_dataset(self.dir_label))
+
+        ### real images
+        if opt.isTrain:
+            self.dir_image = os.path.join(opt.dataroot, opt.phase + '_img')
+            self.image_paths = sorted(make_dataset(self.dir_image))
+
+        ### instance maps
+        if not opt.no_instance:
+            self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
+            self.inst_paths = sorted(make_dataset(self.dir_inst))
+
+        ### load precomputed instance-wise encoded features
+        if opt.load_features:
+            self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
+            print('----------- loading features from %s ----------' % self.dir_feat)
+            self.feat_paths = sorted(make_dataset(self.dir_feat))
+
+        self.dataset_size = len(self.label_paths)
+
+    def __getitem__(self, index):
+        ### label maps
+        label_path = self.label_paths[index]
+        label = Image.open(label_path)
+        params = get_params(self.opt, label.size)
+        transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
+        label_tensor = transform_label(label) * 255.0
+
+        image_tensor = inst_tensor = feat_tensor = 0
+        ### real images
+        if self.opt.isTrain:
+            image_path = self.image_paths[index]
+            image = Image.open(image_path).convert('RGB')
+            transform_image = get_transform(self.opt, params)
+            image_tensor = transform_image(image)
+
+        ### if using instance maps
+        if not self.opt.no_instance:
+            inst_path = self.inst_paths[index]
+            inst = Image.open(inst_path)
+            inst_tensor = transform_label(inst)
+
+            if self.opt.load_features:
+                feat_path = self.feat_paths[index]
+                feat = Image.open(feat_path).convert('RGB')
+                norm = normalize()
+                feat_tensor = norm(transform_label(feat))
+
+        input_dict = {'label': label_tensor, 'inst': inst_tensor, 'image': image_tensor,
+                      'feat': feat_tensor, 'path': label_path}
+
+        return input_dict
+
+    def __len__(self):
+        return len(self.label_paths)
+
+    def name(self):
+        return 'AlignedDataset'
\ No newline at end of file
diff --git a/data/base_data_loader.py b/data/base_data_loader.py
new file mode 100755
index 0000000..0e1deb5
--- /dev/null
+++ b/data/base_data_loader.py
@@ -0,0 +1,14 @@
+
+class BaseDataLoader():
+    def __init__(self):
+        pass
+
+    def initialize(self, opt):
+        self.opt = opt
+        pass
+
+    def load_data(self):  # stub, overridden by subclasses; `self` was missing in the original
+        return None
+
+
+
diff --git a/data/base_dataset.py b/data/base_dataset.py
new file mode 100755
index 0000000..038d3d2
--- /dev/null
+++ b/data/base_dataset.py
@@ -0,0 +1,92 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms
+import numpy as np
+import random
+
+class BaseDataset(data.Dataset):
+    def __init__(self):
+        super(BaseDataset, self).__init__()
+
+    def name(self):
+        return 'BaseDataset'
+
+    def initialize(self, opt):
+        pass
+
+def get_params(opt, size):
+    w, h = size
+    new_h = h
+    new_w = w
+    if opt.resize_or_crop == 'resize_and_crop':
+        new_h = new_w = opt.loadSize
+    elif opt.resize_or_crop == 'scale_width_and_crop':
+        new_w = opt.loadSize
+        new_h = opt.loadSize * h // w
+
+    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
+    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
+
+    flip = random.random() > 0.5
+    return {'crop_pos': (x, y), 'flip': flip}
+
+def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
+    transform_list = []
+    if 'resize' in opt.resize_or_crop:
+        osize = [opt.loadSize, opt.loadSize]
+        transform_list.append(transforms.Scale(osize, method))
+    elif 'scale_width' in opt.resize_or_crop:
+        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
+
+    if 'crop' in opt.resize_or_crop:
+        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
+
+    if opt.resize_or_crop == 'none':
+        base = float(2 ** opt.n_downsample_global)
+        if opt.netG == 'local':
+            base *= (2 ** opt.n_local_enhancers)
+        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
+
+    if opt.isTrain and not opt.no_flip:
+        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
+
+    transform_list += [transforms.ToTensor()]
+
+    if normalize:
+        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
+                                                (0.5, 0.5, 0.5))]
+    return transforms.Compose(transform_list)
+
+def normalize():
+    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+
+def __make_power_2(img, base, method=Image.BICUBIC):
+    ow, oh = img.size
+    h = int(round(oh / base) * base)
+    w = int(round(ow / base) * base)
+    if (h == oh) and (w == ow):
+        return img
+    return img.resize((w, h), method)
+
+def __scale_width(img, target_width, method=Image.BICUBIC):
+    ow, oh = img.size
+    if (ow == target_width):
+        return img
+    w = target_width
+    h = int(target_width * oh / ow)
+    return img.resize((w, h), method)
+
+def __crop(img, pos, size):
+    ow, oh = img.size
+    x1, y1 = pos
+    tw = th = size
+    if (ow > tw or oh > th):
+        return img.crop((x1, y1, x1 + tw, y1 + th))
+    return img
+
+def __flip(img, flip):
+    if flip:
+        return img.transpose(Image.FLIP_LEFT_RIGHT)
+    return img
diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py
new file mode 100755
index 0000000..0b98254
--- /dev/null
+++ b/data/custom_dataset_data_loader.py
@@ -0,0 +1,31 @@
+import torch.utils.data
+from data.base_data_loader import BaseDataLoader
+
+
+def CreateDataset(opt):
+    dataset = None
+    from data.aligned_dataset import AlignedDataset
+    dataset = AlignedDataset()
+
+    print("dataset [%s] was created" % (dataset.name()))
+    dataset.initialize(opt)
+    return dataset
+
+class CustomDatasetDataLoader(BaseDataLoader):
+    def name(self):
+        return 'CustomDatasetDataLoader'
+
+    def initialize(self, opt):
+        BaseDataLoader.initialize(self, opt)
+        self.dataset = CreateDataset(opt)
+        self.dataloader = torch.utils.data.DataLoader(
+            self.dataset,
+            batch_size=opt.batchSize,
+            shuffle=not opt.serial_batches,
+            num_workers=int(opt.nThreads))
+
+    def load_data(self):
+        return self.dataloader
+
+    def __len__(self):
+        return min(len(self.dataset), self.opt.max_dataset_size)
diff --git a/data/data_loader.py b/data/data_loader.py
new file mode 100755
index 0000000..2a4433a
--- /dev/null
+++ b/data/data_loader.py
@@ -0,0 +1,7 @@
+
+def CreateDataLoader(opt):
+    from data.custom_dataset_data_loader import CustomDatasetDataLoader
+    data_loader = CustomDatasetDataLoader()
+    print(data_loader.name())
+    data_loader.initialize(opt)
+    return data_loader
diff --git a/data/image_folder.py b/data/image_folder.py
new file mode 100755
index 0000000..16a447c
--- /dev/null
+++ b/data/image_folder.py
@@ -0,0 +1,68 @@
+###############################################################################
+# Code from
+# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
+# Modified the original code so that it also loads images from the current
+# directory as well as the subdirectories
+###############################################################################
+
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+
+IMG_EXTENSIONS = [
+    '.jpg', '.JPG', '.jpeg', '.JPEG',
+    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
+]
+
+
+def is_image_file(filename):
+    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir):
+    images = []
+    assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+    for root, _, fnames in sorted(os.walk(dir)):
+        for fname in fnames:
+            if is_image_file(fname):
+                path = os.path.join(root, fname)
+                images.append(path)
+
+    return images
+
+
+def default_loader(path):
+    return Image.open(path).convert('RGB')
+
+
+class ImageFolder(data.Dataset):
+
+    def __init__(self, root, transform=None, return_paths=False,
+                 loader=default_loader):
+        imgs = make_dataset(root)
+        if len(imgs) == 0:
+            raise(RuntimeError("Found 0 images in: " + root + "\n"
+                               "Supported image extensions are: " +
+                               ",".join(IMG_EXTENSIONS)))
+
+        self.root = root
+        self.imgs = imgs
+        self.transform = transform
+        self.return_paths = return_paths
+        self.loader = loader
+
+    def __getitem__(self, index):
+        path = self.imgs[index]
+        img = self.loader(path)
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.return_paths:
+            return img, path
+        else:
+            return img
+
+    def __len__(self):
+        return len(self.imgs)
The remaining additions in this commit are binary example images (45 new PNGs, mode 100755, no text diff):
- datasets/cityscapes/test_inst/ -- 15 frankfurt_*_gtFine_instanceIds.png instance maps
- datasets/cityscapes/test_label/ -- 15 frankfurt_*_gtFine_labelIds.png label maps (same frames)
- datasets/cityscapes/train_img/ -- 5 aachen_*_leftImg8bit.png real images
- datasets/cityscapes/train_inst/ -- 5 aachen_*_gtFine_instanceIds.png instance maps
- datasets/cityscapes/train_label/ -- 5 aachen_*_gtFine_labelIds.png label maps
diff --git a/encode_features.py b/encode_features.py
new file mode 100755
index 0000000..0e97da8
--- /dev/null
+++ b/encode_features.py
@@ -0,0 +1,57 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+from options.train_options import TrainOptions
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+import numpy as np
+import os, time
+import util.util as util
+from torch.autograd import Variable
+
+opt = TrainOptions().parse()
+opt.nThreads = 1
+opt.batchSize = 1
+opt.serial_batches = True
+opt.no_flip = True
+opt.instance_feat = True
+
+name = 'features'
+save_path = os.path.join(opt.checkpoints_dir, opt.name)
+
+############ Initialize #########
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+dataset_size = len(data_loader)
+model = create_model(opt)
+
+########### Encode features ###########
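+# Each row collected below holds the opt.feat_num encoded values for one instance
+# plus one extra column; that last column is thresholded later (feat[:,-1] > 0.5)
+# to filter instances before clustering.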
+reencode = True
+if reencode:
+    features = {}
+    for label in range(opt.label_nc):
+        features[label] = np.zeros((0, opt.feat_num+1))
+    for i, data in enumerate(dataset):
+        feat = model.module.encode_features(data['image'], data['inst'])
+        for label in range(opt.label_nc):
+            features[label] = np.append(features[label], feat[label], axis=0)
+
+        print('%d / %d images' % (i+1, dataset_size))
+    save_name = os.path.join(save_path, name + '.npy')
+    np.save(save_name, features)
+
+############## Clustering ###########
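+# Run K-means per label class; the saved centers are sampled at test time by
+# sample_features() in models/pix2pixHD_model.py, which reads opt.cluster_path.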
+n_clusters = opt.n_clusters
+load_name = os.path.join(save_path, name + '.npy')
+features = np.load(load_name).item()
+from sklearn.cluster import KMeans
+centers = {}
+for label in range(opt.label_nc):
+    feat = features[label]
+    feat = feat[feat[:,-1] > 0.5, :-1]
+    if feat.shape[0]:
+        n_clusters = min(feat.shape[0], opt.n_clusters)
+        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(feat)
+        centers[label] = kmeans.cluster_centers_
+save_name = os.path.join(save_path, name + '_clustered_%03d.npy' % opt.n_clusters)
+np.save(save_name, centers)
+print('saving to %s' % save_name)
\ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100755 index 0000000..e69de29 --- /dev/null +++ b/models/__init__.py diff --git a/models/base_model.py b/models/base_model.py new file mode 100755 index 0000000..d3879d0 --- /dev/null +++ b/models/base_model.py @@ -0,0 +1,86 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import os +import torch + +class BaseModel(torch.nn.Module): + def name(self): + return 'BaseModel' + + def initialize(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + + def set_input(self, input): + self.input = input + + def forward(self): + pass + + # used in test time, no backprop + def test(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, network_label, epoch_label, gpu_ids): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + torch.save(network.cpu().state_dict(), save_path) + if len(gpu_ids) and torch.cuda.is_available(): + network.cuda() + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label, save_dir=''): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + if not save_dir: + save_dir = self.save_dir + save_path = os.path.join(save_dir, save_filename) + if not os.path.isfile(save_path): + print('%s not exists yet!' % save_path) + if network_label == 'G': + raise('Generator must exist!') + else: + #network.load_state_dict(torch.load(save_path)) + try: + network.load_state_dict(torch.load(save_path)) + except: + pretrained_dict = torch.load(save_path) + model_dict = network.state_dict() + try: + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + network.load_state_dict(pretrained_dict) + print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) + except: + print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label) + from sets import Set + not_initialized = Set() + for k, v in pretrained_dict.items(): + if v.size() == model_dict[k].size(): + model_dict[k] = v + + for k, v in model_dict.items(): + if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): + not_initialized.add(k.split('.')[0]) + print(sorted(not_initialized)) + network.load_state_dict(model_dict) + + def update_learning_rate(): + pass diff --git a/models/models.py b/models/models.py new file mode 100755 index 0000000..351483c --- /dev/null +++ b/models/models.py @@ -0,0 +1,14 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
+import torch + +def create_model(opt): + from .pix2pixHD_model import Pix2PixHDModel + model = Pix2PixHDModel() + model.initialize(opt) + print("model [%s] was created" % (model.name())) + + if opt.isTrain and len(opt.gpu_ids): + model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) + + return model diff --git a/models/networks.py b/models/networks.py new file mode 100755 index 0000000..a673a56 --- /dev/null +++ b/models/networks.py @@ -0,0 +1,421 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import torch +import torch.nn as nn +from torch.nn import init +import functools +from torch.autograd import Variable +import numpy as np +import math +import torch.nn.functional as F +import copy + +############################################################################### +# Functions +############################################################################### +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find('BatchNorm2d') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + +def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1, + n_blocks_local=3, norm='instance', gpu_ids=[]): + norm_layer = get_norm_layer(norm_type=norm) + if netG == 'global': + netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer) + elif netG == 'local': + netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, + n_local_enhancers, n_blocks_local, norm_layer) + elif netG == 'encoder': + netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer) + else: + raise('generator not implemented!') + print(netG) + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + netG.cuda(device_id=gpu_ids[0]) + netG.apply(weights_init) + return netG + +def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]): + norm_layer = get_norm_layer(norm_type=norm) + netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat) + print(netD) + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + netD.cuda(device_id=gpu_ids[0]) + netD.apply(weights_init) + return netD + +def print_network(net): + if isinstance(net, list): + net = net[0] + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print(net) + print('Total number of parameters: %d' % num_params) + +############################################################################## +# Losses +############################################################################## +class GANLoss(nn.Module): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, + tensor=torch.FloatTensor): + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_var = None + self.fake_label_var = None + self.Tensor = tensor + if use_lsgan: + self.loss = 
nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, target_is_real): + target_tensor = None + if target_is_real: + create_label = ((self.real_label_var is None) or + (self.real_label_var.numel() != input.numel())) + if create_label: + real_tensor = self.Tensor(input.size()).fill_(self.real_label) + self.real_label_var = Variable(real_tensor, requires_grad=False) + target_tensor = self.real_label_var + else: + create_label = ((self.fake_label_var is None) or + (self.fake_label_var.numel() != input.numel())) + if create_label: + fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) + self.fake_label_var = Variable(fake_tensor, requires_grad=False) + target_tensor = self.fake_label_var + return target_tensor + + def __call__(self, input, target_is_real): + if isinstance(input[0], list): + loss = 0 + for input_i in input: + pred = input_i[-1] + target_tensor = self.get_target_tensor(pred, target_is_real) + loss += self.loss(pred, target_tensor) + return loss + else: + target_tensor = self.get_target_tensor(input[-1], target_is_real) + return self.loss(input[-1], target_tensor) + +class VGGLoss(nn.Module): + def __init__(self, gpu_ids): + super(VGGLoss, self).__init__() + self.vgg = Vgg19().cuda() + self.criterion = nn.L1Loss() + self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] + + def forward(self, x, y): + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + loss = 0 + for i in range(len(x_vgg)): + loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) + return loss + +############################################################################## +# Generator +############################################################################## +class LocalEnhancer(nn.Module): + def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9, + n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'): + super(LocalEnhancer, self).__init__() + self.n_local_enhancers = n_local_enhancers + + ###### global generator model ##### + ngf_global = ngf * (2**n_local_enhancers) + model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model + model_global = [model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers + self.model = nn.Sequential(*model_global) + + ###### local enhancer layers ##### + for n in range(1, n_local_enhancers+1): + ### downsample + ngf_global = ngf * (2**(n_local_enhancers-n)) + model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0), + norm_layer(ngf_global), nn.ReLU(True), + nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1), + norm_layer(ngf_global * 2), nn.ReLU(True)] + ### residual blocks + model_upsample = [] + for i in range(n_blocks_local): + model_upsample += [ResnetBlock(ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer)] + + ### upsample + model_upsample += [nn.ConvTranspose2d(ngf_global * 2, ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1), + norm_layer(ngf_global), nn.ReLU(True)] + + ### final convolution + if n == n_local_enhancers: + model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] + + setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample)) + setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample)) + + self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) + + def 
forward(self, input): + ### create input pyramid + input_downsampled = [input] + for i in range(self.n_local_enhancers): + input_downsampled.append(self.downsample(input_downsampled[-1])) + + ### output at coarest level + output_prev = self.model(input_downsampled[-1]) + ### build up one layer at a time + for n_local_enhancers in range(1, self.n_local_enhancers+1): + model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1') + model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2') + input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers] + output_prev = model_upsample(model_downsample(input_i) + output_prev) + return output_prev + +class GlobalGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, + padding_type='reflect'): + assert(n_blocks >= 0) + super(GlobalGenerator, self).__init__() + activation = nn.ReLU(True) + + model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation] + ### downsample + for i in range(n_downsampling): + mult = 2**i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), + norm_layer(ngf * mult * 2), activation] + + ### resnet blocks + mult = 2**n_downsampling + for i in range(n_blocks): + model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)] + + ### upsample + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1), + norm_layer(int(ngf * mult / 2)), activation] + model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + +# Define a resnet block +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout) + + def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim), + activation] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + +class Encoder(nn.Module): + def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d): + super(Encoder, self).__init__() + self.output_nc = output_nc + + model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), nn.ReLU(True)] + ### downsample + for i in range(n_downsampling): + 
mult = 2**i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), + norm_layer(ngf * mult * 2), nn.ReLU(True)] + + ### upsample + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1), + norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] + + model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*model) + + def forward(self, input, inst): + outputs = self.model(input) + + # instance-wise average pooling + outputs_mean = outputs.clone() + inst_list = np.unique(inst.cpu().numpy().astype(int)) + for i in inst_list: + indices = (inst == i).nonzero() # n x 4 + for j in range(self.output_nc): + output_ins = outputs[indices[:,0], indices[:,1] + j, indices[:,2], indices[:,3]] + mean_feat = torch.mean(output_ins).expand_as(output_ins) + outputs_mean[indices[:,0], indices[:,1] + j, indices[:,2], indices[:,3]] = mean_feat + return outputs_mean + +class MultiscaleDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, + use_sigmoid=False, num_D=3, getIntermFeat=False): + super(MultiscaleDiscriminator, self).__init__() + self.num_D = num_D + self.n_layers = n_layers + self.getIntermFeat = getIntermFeat + + for i in range(num_D): + netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat) + if getIntermFeat: + for j in range(n_layers+2): + setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j))) + else: + setattr(self, 'layer'+str(i), netD.model) + + self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) + + def singleD_forward(self, model, input): + if self.getIntermFeat: + result = [input] + for i in range(len(model)): + result.append(model[i](result[-1])) + return result[1:] + else: + return [model(input)] + + def forward(self, input): + num_D = self.num_D + result = [] + input_downsampled = input + for i in range(num_D): + if self.getIntermFeat: + model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)] + else: + model = getattr(self, 'layer'+str(num_D-1-i)) + result.append(self.singleD_forward(model, input_downsampled)) + if i != (num_D-1): + input_downsampled = self.downsample(input_downsampled) + return result + +# Defines the PatchGAN discriminator with the specified arguments. 
+class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [nn.Sigmoid()] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[1:] + else: + return self.model(input) + +from torchvision import models +class Vgg19(torch.nn.Module): + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py new file mode 100755 index 0000000..ba44e53 --- /dev/null +++ b/models/pix2pixHD_model.py @@ -0,0 +1,260 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import numpy as np +import torch +import os +from collections import OrderedDict +from torch.autograd import Variable +import util.util as util +from util.image_pool import ImagePool +from .base_model import BaseModel +from . 
import networks + +class Pix2PixHDModel(BaseModel): + def name(self): + return 'Pix2PixHDModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + if opt.resize_or_crop != 'none': # when training at full res this causes OOM + torch.backends.cudnn.benchmark = True + self.isTrain = opt.isTrain + self.use_features = opt.instance_feat or opt.label_feat + self.gen_features = self.use_features and not self.opt.load_features + + ##### define networks + # Generator network + netG_input_nc = opt.label_nc + if not opt.no_instance: + netG_input_nc += 1 + if self.use_features: + netG_input_nc += opt.feat_num + self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, + opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, + opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids) + + # Discriminator network + if self.isTrain: + use_sigmoid = opt.no_lsgan + netD_input_nc = opt.label_nc + opt.output_nc + if not opt.no_instance: + netD_input_nc += 1 + self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, + opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) + + ### Encoder network + if self.gen_features: + self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', + opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids) + + print('---------- Networks initialized -------------') + + # load networks + if not self.isTrain or opt.continue_train or opt.load_pretrain: + pretrained_path = '' if not self.isTrain else opt.load_pretrain + self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) + if self.isTrain: + self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) + if self.gen_features: + self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path) + + # set loss functions and optimizers + if self.isTrain: + if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: + raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") + self.fake_pool = ImagePool(opt.pool_size) + self.old_lr = opt.lr + + # define loss functions + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + self.criterionFeat = torch.nn.L1Loss() + if not opt.no_vgg_loss: + self.criterionVGG = networks.VGGLoss(self.gpu_ids) + + # Names so we can breakout loss + self.loss_names = ['G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake'] + + # initialize optimizers + # optimizer G + if opt.niter_fix_global > 0: + print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global) + params_dict = dict(self.netG.named_parameters()) + params = [] + for key, value in params_dict.items(): + if key.startswith('model' + str(opt.n_local_enhancers)): + params += [{'params':[value],'lr':opt.lr}] + else: + params += [{'params':[value],'lr':0.0}] + else: + params = list(self.netG.parameters()) + if self.gen_features: + params += list(self.netE.parameters()) + self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + + # optimizer D + params = list(self.netD.parameters()) + self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + + def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): + # create one-hot vector for label map + size = label_map.size() + oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) + input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() + input_label = input_label.scatter_(1, label_map.data.long().cuda(), 
+
+        # get edges from instance map
+        if not self.opt.no_instance:
+            inst_map = inst_map.data.cuda()
+            edge_map = self.get_edges(inst_map)
+            input_label = torch.cat((input_label, edge_map), dim=1)
+        input_label = Variable(input_label, volatile=infer)
+
+        # real images for training
+        if real_image is not None:
+            real_image = Variable(real_image.data.cuda())
+
+        # instance map for feature encoding
+        if self.use_features:
+            # get precomputed feature maps
+            if self.opt.load_features:
+                feat_map = Variable(feat_map.data.cuda())
+
+        return input_label, inst_map, real_image, feat_map
+
+    def discriminate(self, input_label, test_image, use_pool=False):
+        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
+        if use_pool:
+            fake_query = self.fake_pool.query(input_concat)
+            return self.netD.forward(fake_query)
+        else:
+            return self.netD.forward(input_concat)
+
+    def forward(self, label, inst, image, feat, infer=False):
+        # Encode Inputs
+        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)
+
+        # Fake Generation
+        if self.use_features:
+            if not self.opt.load_features:
+                feat_map = self.netE.forward(real_image, inst_map)
+            input_concat = torch.cat((input_label, feat_map), dim=1)
+        else:
+            input_concat = input_label
+        fake_image = self.netG.forward(input_concat)
+
+        # Fake Detection and Loss
+        pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
+        loss_D_fake = self.criterionGAN(pred_fake_pool, False)
+
+        # Real Detection and Loss
+        pred_real = self.discriminate(input_label, real_image)
+        loss_D_real = self.criterionGAN(pred_real, True)
+
+        # GAN loss (Fake Passability Loss)
+        pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
+        loss_G_GAN = self.criterionGAN(pred_fake, True)
+
+        # GAN feature matching loss
+        loss_G_GAN_Feat = 0
+        if not self.opt.no_ganFeat_loss:
+            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
+            D_weights = 1.0 / self.opt.num_D
+            for i in range(self.opt.num_D):
+                for j in range(len(pred_fake[i]) - 1):
+                    loss_G_GAN_Feat += D_weights * feat_weights * \
+                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
+
+        # VGG feature matching loss
+        loss_G_VGG = 0
+        if not self.opt.no_vgg_loss:
+            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
+
+        # Only return the generated image when needed, to save bandwidth
+        return [[loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake], None if not infer else fake_image]
+
+    def inference(self, label, inst):
+        # Encode Inputs
+        input_label, inst_map, _, _ = self.encode_input(Variable(label), Variable(inst), infer=True)
+
+        # Fake Generation
+        if self.use_features:
+            # sample clusters from precomputed features
+            feat_map = self.sample_features(inst_map)
+            input_concat = torch.cat((input_label, feat_map), dim=1)
+        else:
+            input_concat = input_label
+        fake_image = self.netG.forward(input_concat)
+        return fake_image
+
+    def sample_features(self, inst):
+        # read precomputed feature clusters
+        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
+        features_clustered = np.load(cluster_path).item()
+
+        # randomly sample from the feature clusters
+        inst_np = inst.cpu().numpy().astype(int)
+        feat_map = torch.cuda.FloatTensor(1, self.opt.feat_num, inst.size()[2], inst.size()[3])
+        for i in np.unique(inst_np):
+            label = i if i < 1000 else i // 1000
+            if label in features_clustered:
+                feat = features_clustered[label]
+                cluster_idx = np.random.randint(0, feat.shape[0])
+
+                idx = (inst == i).nonzero()
+                for k in range(self.opt.feat_num):
+                    feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
+        return feat_map
+
+    def encode_features(self, image, inst):
+        image = Variable(image.cuda(), volatile=True)
+        feat_num = self.opt.feat_num
+        h, w = inst.size()[2], inst.size()[3]
+        block_num = 32
+        feat_map = self.netE.forward(image, inst.cuda())
+        inst_np = inst.cpu().numpy().astype(int)
+        feature = {}
+        for i in range(self.opt.label_nc):
+            feature[i] = np.zeros((0, feat_num+1))
+        for i in np.unique(inst_np):
+            label = i if i < 1000 else i // 1000
+            idx = (inst == i).nonzero()
+            num = idx.size()[0]
+            idx = idx[num//2,:]
+            val = np.zeros((1, feat_num+1))
+            for k in range(feat_num):
+                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
+            val[0, feat_num] = float(num) / (h * w // block_num)
+            feature[label] = np.append(feature[label], val, axis=0)
+        return feature
+
+    def get_edges(self, t):
+        edge = torch.cuda.ByteTensor(t.size()).zero_()
+        edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
+        edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
+        edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
+        edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
+        return edge.float()
+
+    def save(self, which_epoch):
+        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
+        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
+        if self.gen_features:
+            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)
+
+    def update_fixed_params(self):
+        # after fixing the global generator for a number of iterations, also start finetuning it
+        params = list(self.netG.parameters())
+        if self.gen_features:
+            params += list(self.netE.parameters())
+        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
+        print('------------ Now also finetuning global generator -----------')
+
+    def update_learning_rate(self):
+        lrd = self.opt.lr / self.opt.niter_decay
+        lr = self.old_lr - lrd
+        for param_group in self.optimizer_D.param_groups:
+            param_group['lr'] = lr
+        for param_group in self.optimizer_G.param_groups:
+            param_group['lr'] = lr
+        print('update learning rate: %f -> %f' % (self.old_lr, lr))
+        self.old_lr = lr
diff --git a/options/__init__.py b/options/__init__.py
new file mode 100755
index 0000000..e69de29
--- /dev/null
+++ b/options/__init__.py
diff --git a/options/base_options.py b/options/base_options.py
new file mode 100755
index 0000000..863c061
--- /dev/null
+++ b/options/base_options.py
@@ -0,0 +1,95 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import argparse
+import os
+from util import util
+import torch
+
+class BaseOptions():
+    def __init__(self):
+        self.parser = argparse.ArgumentParser()
+        self.initialized = False
+
+    def initialize(self):
+        # experiment specifics
+        self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models')
+        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
+        self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
+        self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
+        self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
+
+        # input/output sizes
+        self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
+        self.parser.add_argument('--loadSize', type=int, default=1024, help='scale images to this size')
+        self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
+        self.parser.add_argument('--label_nc', type=int, default=35, help='# of input label channels (label classes)')
+        self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
+
+        # for setting inputs
+        self.parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/')
+        self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
+        self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+        self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
+        self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
+        self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+
+        # for displays
+        self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
+        self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
+
+        # for generator
+        self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
+        self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
+        self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG')
+        self.parser.add_argument('--n_blocks_global', type=int, default=9, help='number of residual blocks in the global generator network')
+        self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network')
+        self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use')
+        self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs during which we only train the outermost local enhancer')
+
+        # for instance-wise features
+        self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
+        self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input')
+        self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input')
+        self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features')
+        self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps')
+        self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder')
+        self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
+        self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features')
+
+        self.initialized = True
+
+    def parse(self, save=True):
+        if not self.initialized:
+            self.initialize()
+        self.opt = self.parser.parse_args()
+        self.opt.isTrain = self.isTrain  # train or test
+
+        str_ids = self.opt.gpu_ids.split(',')
+        self.opt.gpu_ids = []
+        for str_id in str_ids:
+            id = int(str_id)
+            if id >= 0:
+                self.opt.gpu_ids.append(id)
+
+        # set gpu ids
+        if len(self.opt.gpu_ids) > 0:
+            torch.cuda.set_device(self.opt.gpu_ids[0])
+
+        args = vars(self.opt)
+
+        print('------------ Options -------------')
+        for k, v in sorted(args.items()):
+            print('%s: %s' % (str(k), str(v)))
+        print('-------------- End ----------------')
+
+        # save to the disk
+        expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
+        util.mkdirs(expr_dir)
+        if save and not self.opt.continue_train:
+            file_name = os.path.join(expr_dir, 'opt.txt')
+            with open(file_name, 'wt') as opt_file:
+                opt_file.write('------------ Options -------------\n')
+                for k, v in sorted(args.items()):
+                    opt_file.write('%s: %s\n' % (str(k), str(v)))
+                opt_file.write('-------------- End ----------------\n')
+        return self.opt
diff --git a/options/test_options.py b/options/test_options.py
new file mode 100755
index 0000000..aaeff53
--- /dev/null
+++ b/options/test_options.py
@@ -0,0 +1,15 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+from .base_options import BaseOptions
+
+class TestOptions(BaseOptions):
+    def initialize(self):
+        BaseOptions.initialize(self)
+        self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
+        self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
+        self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
+        self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+        self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
+        self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
+        self.isTrain = False
diff --git a/options/train_options.py b/options/train_options.py
new file mode 100755
index 0000000..9994a4a
--- /dev/null
+++ b/options/train_options.py
@@ -0,0 +1,36 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+from .base_options import BaseOptions
+
+class TrainOptions(BaseOptions):
+    def initialize(self):
+        BaseOptions.initialize(self)
+        # for displays
+        self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
+        self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
+        self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
+        self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
+        self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
+        self.parser.add_argument('--debug', action='store_true', help='only do one epoch and display at each iteration')
+
+        # for training
+        self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
+        self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
+        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+        self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
+        self.parser.add_argument('--niter', type=int, default=100, help='# of epochs at the starting learning rate')
+        self.parser.add_argument('--niter_decay', type=int, default=100, help='# of epochs over which to linearly decay the learning rate to zero')
+        self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
+        self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
+
+        # for discriminators
+        self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
+        self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
+        self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
+        self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
+        self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
+        self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
+        self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least-squares GAN; if set, use the vanilla GAN loss')
+        self.parser.add_argument('--pool_size', type=int, default=0, help='the size of the image buffer that stores previously generated images')
+
+        self.isTrain = True
diff --git a/precompute_feature_maps.py b/precompute_feature_maps.py
new file mode 100755
index 0000000..a631b9c
--- /dev/null
+++ b/precompute_feature_maps.py
@@ -0,0 +1,36 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+from options.train_options import TrainOptions
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+import numpy as np
+import os, time
+import util.util as util
+from torch.autograd import Variable
+import torch.nn as nn
+
+opt = TrainOptions().parse()
+opt.nThreads = 1
+opt.batchSize = 1
+opt.serial_batches = True
+opt.no_flip = True
+opt.instance_feat = True
+
+name = 'features'
+save_path = os.path.join(opt.checkpoints_dir, opt.name)
+
+############ Initialize #########
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+dataset_size = len(data_loader)
+model = create_model(opt)
+util.mkdirs(os.path.join(opt.dataroot, opt.phase + '_feat'))
+
+######## Save precomputed feature maps for 1024p training #######
+for i, data in enumerate(dataset):
+    print('%d / %d images' % (i+1, dataset_size))
+    feat_map = model.module.netE.forward(Variable(data['image'].cuda(), volatile=True), data['inst'].cuda())
+    feat_map = nn.Upsample(scale_factor=2, mode='nearest')(feat_map)
+    image_numpy = util.tensor2im(feat_map.data[0])
+    save_path = data['path'][0].replace('/train_label/', '/train_feat/')
+    util.save_image(image_numpy, save_path)
\ No newline at end of file
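Aside: before the encoder features saved above can be reused, the model one-hot encodes every label map on the way in. A minimal, CPU-only sketch of the scatter_-based encoding from Pix2PixHDModel.encode_input earlier; batch size, label count, and spatial size below are hypothetical:

import torch

# hypothetical 4x4 label map with 35 classes, batch size 1
label_map = torch.LongTensor(1, 1, 4, 4).random_(0, 35)
one_hot = torch.zeros(1, 35, 4, 4)
one_hot.scatter_(1, label_map, 1.0)  # channel c becomes 1 wherever label == c
assert one_hot.sum().item() == 16    # exactly one active channel per pixel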
diff --git a/scripts/test_1024p.sh b/scripts/test_1024p.sh
new file mode 100755
index 0000000..99c1e24
--- /dev/null
+++ b/scripts/test_1024p.sh
@@ -0,0 +1,3 @@
+################################ Testing ################################
+# labels only
+python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none
\ No newline at end of file
diff --git a/scripts/test_1024p_feat.sh b/scripts/test_1024p_feat.sh
new file mode 100755
index 0000000..2f4ba17
--- /dev/null
+++ b/scripts/test_1024p_feat.sh
@@ -0,0 +1,5 @@
+################################ Testing ################################
+# first precompute and cluster all features
+python encode_features.py --name label2city_1024p_feat --netG local --ngf 32 --resize_or_crop none;
+# use instance-wise features
+python test.py --name label2city_1024p_feat --netG local --ngf 32 --resize_or_crop none --instance_feat
\ No newline at end of file
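The instance-wise runs above feed the model an extra boundary channel derived from the instance map by get_edges in models/pix2pix_model.py. A CPU-only sketch of the same shift-and-compare logic on a toy map (sizes hypothetical):

import torch

t = torch.tensor([[[[1, 1, 2],
                    [1, 1, 2],
                    [3, 3, 3]]]])            # toy 1x1x3x3 instance map
edge = torch.zeros_like(t, dtype=torch.bool)
edge[:, :, :, 1:]  |= t[:, :, :, 1:] != t[:, :, :, :-1]   # compare horizontal neighbours
edge[:, :, :, :-1] |= t[:, :, :, 1:] != t[:, :, :, :-1]
edge[:, :, 1:, :]  |= t[:, :, 1:, :] != t[:, :, :-1, :]   # compare vertical neighbours
edge[:, :, :-1, :] |= t[:, :, 1:, :] != t[:, :, :-1, :]
print(edge.float())  # 1.0 on every pixel that touches a different instance id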
diff --git a/scripts/test_512p.sh b/scripts/test_512p.sh
new file mode 100755
index 0000000..3131043
--- /dev/null
+++ b/scripts/test_512p.sh
@@ -0,0 +1,3 @@
+################################ Testing ################################
+# labels only
+python test.py --name label2city_512p
\ No newline at end of file
diff --git a/scripts/test_512p_feat.sh b/scripts/test_512p_feat.sh
new file mode 100755
index 0000000..8f25e9c
--- /dev/null
+++ b/scripts/test_512p_feat.sh
@@ -0,0 +1,5 @@
+################################ Testing ################################
+# first precompute and cluster all features
+python encode_features.py --name label2city_512p_feat;
+# use instance-wise features
+python test.py --name label2city_512p_feat --instance_feat
\ No newline at end of file
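encode_features.py (run first above) saves per-label feature clusters that sample_features in models/pix2pix_model.py reloads at test time. Judging from that loading code, the saved object is a dict mapping label id to an (n_clusters, feat_num + 1) array; a hedged sketch of consuming it, with a hypothetical checkpoint path and label id:

import numpy as np

clusters = np.load('checkpoints/label2city_512p_feat/features_clustered_010.npy').item()
feat = clusters[26]                              # one label's cluster centres, if present
row = feat[np.random.randint(0, feat.shape[0])]  # sample one cluster, as sample_features does
print(row[:-1])                                  # the first feat_num values feed the generator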
diff --git a/scripts/train_1024p_12G.sh b/scripts/train_1024p_12G.sh
new file mode 100755
index 0000000..d5ea7d7
--- /dev/null
+++ b/scripts/train_1024p_12G.sh
@@ -0,0 +1,4 @@
+############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
+##### Using GPUs with 12G memory (not tested)
+# Using labels only
+python train.py --name label2city_1024p --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p/ --niter_fix_global 20 --resize_or_crop crop --fineSize 1024
\ No newline at end of file
diff --git a/scripts/train_1024p_24G.sh b/scripts/train_1024p_24G.sh
new file mode 100755
index 0000000..88e58f7
--- /dev/null
+++ b/scripts/train_1024p_24G.sh
@@ -0,0 +1,4 @@
+############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
+######## Using GPUs with 24G memory
+# Using labels only
+python train.py --name label2city_1024p --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p/ --niter 50 --niter_decay 50 --niter_fix_global 10 --resize_or_crop none
\ No newline at end of file
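Both 1024p schedules lean on --niter_fix_global: initialize() in models/pix2pix_model.py gives every parameter group outside the local enhancer a learning rate of 0, so only the enhancer trains at first. A standalone sketch of that per-group trick; the two Linear layers merely stand in for the real subnetworks:

import torch
import torch.nn as nn

netG = nn.ModuleDict({'model': nn.Linear(4, 4),     # stand-in for the global generator
                      'model1': nn.Linear(4, 4)})   # stand-in for the local enhancer
params = []
for name, p in netG.named_parameters():
    lr = 2e-4 if name.startswith('model1') else 0.0  # 'model1' matches n_local_enhancers=1
    params.append({'params': [p], 'lr': lr})
opt = torch.optim.Adam(params, lr=2e-4, betas=(0.5, 0.999))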
diff --git a/scripts/train_1024p_feat_12G.sh b/scripts/train_1024p_feat_12G.sh
new file mode 100755
index 0000000..f8e3d61
--- /dev/null
+++ b/scripts/train_1024p_feat_12G.sh
@@ -0,0 +1,6 @@
+############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
+##### Using GPUs with 12G memory (not tested)
+# First precompute feature maps and save them
+python precompute_feature_maps.py --name label2city_512p_feat;
+# Adding instances and encoded features
+python train.py --name label2city_1024p_feat --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p_feat/ --niter_fix_global 20 --resize_or_crop crop --fineSize 896 --instance_feat --load_features
\ No newline at end of file
diff --git a/scripts/train_1024p_feat_24G.sh b/scripts/train_1024p_feat_24G.sh
new file mode 100755
index 0000000..399d720
--- /dev/null
+++ b/scripts/train_1024p_feat_24G.sh
@@ -0,0 +1,6 @@
+############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
+######## Using GPUs with 24G memory
+# First precompute feature maps and save them
+python precompute_feature_maps.py --name label2city_512p_feat;
+# Adding instances and encoded features
+python train.py --name label2city_1024p_feat --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p_feat/ --niter 50 --niter_decay 50 --niter_fix_global 10 --resize_or_crop none --instance_feat --load_features
\ No newline at end of file
diff --git a/scripts/train_512p.sh b/scripts/train_512p.sh
new file mode 100755
index 0000000..222c348
--- /dev/null
+++ b/scripts/train_512p.sh
@@ -0,0 +1,2 @@
+### Using labels only
+python train.py --name label2city_512p
\ No newline at end of file
diff --git a/scripts/train_512p_feat.sh b/scripts/train_512p_feat.sh
new file mode 100755
index 0000000..9d4859c
--- /dev/null
+++ b/scripts/train_512p_feat.sh
@@ -0,0 +1,2 @@
+### Adding instances and encoded features
+python train.py --name label2city_512p_feat --instance_feat
\ No newline at end of file
diff --git a/scripts/train_512p_multigpu.sh b/scripts/train_512p_multigpu.sh
new file mode 100755
index 0000000..16f0a1a
--- /dev/null
+++ b/scripts/train_512p_multigpu.sh
@@ -0,0 +1,2 @@
+######## Multi-GPU training example #######
+python train.py --name label2city_512p --batchSize 8 --gpu_ids 0,1,2,3,4,5,6,7
\ No newline at end of file
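One more schedule detail before train.py below: past --niter epochs, update_learning_rate in models/pix2pix_model.py subtracts a fixed lr / niter_decay every epoch, so the rate reaches zero exactly when --niter_decay epochs have elapsed. A quick arithmetic check with the defaults:

lr, niter_decay = 0.0002, 100
lrd = lr / niter_decay                                     # 2e-06 subtracted per decay epoch
rates = [lr - lrd * k for k in range(1, niter_decay + 1)]
print(rates[0], rates[-1])                                 # 0.000198 ... ~0.0 (zero up to float rounding)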
diff --git a/test.py b/test.py
new file mode 100755
--- /dev/null
+++ b/test.py
@@ -0,0 +1,37 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import time
+import os
+from collections import OrderedDict
+from options.test_options import TestOptions
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+import util.util as util
+from util.visualizer import Visualizer
+from util import html
+
+opt = TestOptions().parse(save=False)
+opt.nThreads = 1   # test code only supports nThreads = 1
+opt.batchSize = 1  # test code only supports batchSize = 1
+opt.serial_batches = True  # no shuffle
+opt.no_flip = True  # no flip
+
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+model = create_model(opt)
+visualizer = Visualizer(opt)
+# create website
+web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
+webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
+# test
+for i, data in enumerate(dataset):
+    if i >= opt.how_many:
+        break
+    generated = model.inference(data['label'], data['inst'])
+    visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
+                           ('synthesized_image', util.tensor2im(generated.data[0]))])
+    img_path = data['path']
+    print('process image... %s' % img_path)
+    visualizer.save_images(webpage, visuals, img_path)
+
+webpage.save()
diff --git a/train.py b/train.py
new file mode 100755
index 0000000..4965481
--- /dev/null
+++ b/train.py
@@ -0,0 +1,118 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import time
+from collections import OrderedDict
+from options.train_options import TrainOptions
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+import util.util as util
+from util.visualizer import Visualizer
+import os
+import numpy as np
+import torch
+from torch.autograd import Variable
+
+opt = TrainOptions().parse()
+iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
+if opt.continue_train:
+    try:
+        start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
+    except:
+        start_epoch, epoch_iter = 1, 0
+    print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
+else:
+    start_epoch, epoch_iter = 1, 0
+
+if opt.debug:
+    opt.display_freq = 1
+    opt.print_freq = 1
+    opt.niter = 1
+    opt.niter_decay = 0
+    opt.max_dataset_size = 10
+
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+dataset_size = len(data_loader)
+print('#training images = %d' % dataset_size)
+
+model = create_model(opt)
+visualizer = Visualizer(opt)
+
+total_steps = (start_epoch - 1) * dataset_size + epoch_iter
+for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
+    epoch_start_time = time.time()
+    if epoch != start_epoch:
+        epoch_iter = epoch_iter % dataset_size
+    for i, data in enumerate(dataset, start=epoch_iter):
+        iter_start_time = time.time()
+        total_steps += opt.batchSize
+        epoch_iter += opt.batchSize
+
+        # whether to collect output images
+        save_fake = total_steps % opt.display_freq == 0
+
+        ############## Forward Pass ######################
+        losses, generated = model(Variable(data['label']), Variable(data['inst']),
+                                  Variable(data['image']), Variable(data['feat']), infer=save_fake)
+
+        # sum per-device losses
+        losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
+        loss_dict = dict(zip(model.module.loss_names, losses))
+
+        # calculate final loss scalars
+        loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
+        loss_G = loss_dict['G_GAN'] + loss_dict['G_GAN_Feat'] + loss_dict['G_VGG']
+
+        ############### Backward Pass ####################
+        # update generator weights
+        model.module.optimizer_G.zero_grad()
+        loss_G.backward()
+        model.module.optimizer_G.step()
+
+        # update discriminator weights
+        model.module.optimizer_D.zero_grad()
+        loss_D.backward()
+        model.module.optimizer_D.step()
+
+        #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])
+
+        ############## Display results and errors ##########
+        ### print out errors
+        if total_steps % opt.print_freq == 0:
+            errors = {k: v.data[0] if not isinstance(v, (int, float)) else v for k, v in loss_dict.items()}
+            t = (time.time() - iter_start_time) / opt.batchSize
+            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
+            visualizer.plot_current_errors(errors, total_steps)
+
+        ### display output images
+        if save_fake:
+            visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
+                                   ('synthesized_image', util.tensor2im(generated.data[0])),
+                                   ('real_image', util.tensor2im(data['image'][0]))])
+            visualizer.display_current_results(visuals, epoch, total_steps)
+
+        ### save latest model
+        if total_steps % opt.save_latest_freq == 0:
+            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
+            model.module.save('latest')
+            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
+
+    # end of epoch
+    iter_end_time = time.time()
+    print('End of epoch %d / %d \t Time Taken: %d sec' %
+          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
+
+    ### save model for this epoch
+    if epoch % opt.save_epoch_freq == 0:
+        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
+        model.module.save('latest')
+        model.module.save(epoch)
+        np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')
+
+    ### instead of only training the local enhancer, train the entire network after certain iterations
+    if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
+        model.module.update_fixed_params()
+
+    ### linearly decay learning rate after certain iterations
+    if epoch > opt.niter:
+        model.module.update_learning_rate()
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100755
index 0000000..e69de29
--- /dev/null
+++ b/util/__init__.py
diff --git a/util/html.py b/util/html.py
new file mode 100755
index 0000000..a80aa59
--- /dev/null
+++ b/util/html.py
@@ -0,0 +1,63 @@
+import dominate
+from dominate.tags import *
+import os
+
+
+class HTML:
+    def __init__(self, web_dir, title, refresh=0):
+        self.title = title
+        self.web_dir = web_dir
+        self.img_dir = os.path.join(self.web_dir, 'images')
+        if not os.path.exists(self.web_dir):
+            os.makedirs(self.web_dir)
+        if not os.path.exists(self.img_dir):
+            os.makedirs(self.img_dir)
+
+        self.doc = dominate.document(title=title)
+        if refresh > 0:
+            with self.doc.head:
+                meta(http_equiv="refresh", content=str(refresh))
+
+    def get_image_dir(self):
+        return self.img_dir
+
+    def add_header(self, str):
+        with self.doc:
+            h3(str)
+
+    def add_table(self, border=1):
+        self.t = table(border=border, style="table-layout: fixed;")
+        self.doc.add(self.t)
+
+    def add_images(self, ims, txts, links, width=512):
+        self.add_table()
+        with self.t:
+            with tr():
+                for im, txt, link in zip(ims, txts, links):
+                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
+                        with p():
+                            with a(href=os.path.join('images', link)):
+                                img(style="width:%dpx" % (width), src=os.path.join('images', im))
+                            br()
+                            p(txt)
+
+    def save(self):
+        html_file = '%s/index.html' % self.web_dir
+        f = open(html_file, 'wt')
+        f.write(self.doc.render())
+        f.close()
+
+
+if __name__ == '__main__':
+    html = HTML('web/', 'test_html')
+    html.add_header('hello world')
+
+    ims = []
+    txts = []
+    links = []
+    for n in range(4):
+        ims.append('image_%d.jpg' % n)
+        txts.append('text_%d' % n)
+        links.append('image_%d.jpg' % n)
+    html.add_images(ims, txts, links)
+    html.save()
diff --git a/util/image_pool.py b/util/image_pool.py
new file mode 100755
index 0000000..152ef5b
--- /dev/null
+++ b/util/image_pool.py
@@ -0,0 +1,32 @@
+import random
+import numpy as np
+import torch
+from torch.autograd import Variable
+class ImagePool():
+    def __init__(self, pool_size):
+        self.pool_size = pool_size
+        if self.pool_size > 0:
+            self.num_imgs = 0
+            self.images = []
+
+    def query(self, images):
+        if self.pool_size == 0:
+            return images
+        return_images = []
+        for image in images.data:
+            image = torch.unsqueeze(image, 0)
+            if self.num_imgs < self.pool_size:
+                self.num_imgs = self.num_imgs + 1
+                self.images.append(image)
+                return_images.append(image)
+            else:
+                p = random.uniform(0, 1)
+                if p > 0.5:
+                    random_id = random.randint(0, self.pool_size - 1)
+                    tmp = self.images[random_id].clone()
+                    self.images[random_id] = image
+                    return_images.append(tmp)
+                else:
+                    return_images.append(image)
+        return_images = Variable(torch.cat(return_images, 0))
+        return return_images
diff --git a/util/util.py b/util/util.py
new file mode 100755
index 0000000..0898f7a
--- /dev/null
+++ b/util/util.py
@@ -0,0 +1,99 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+from __future__ import print_function
+import torch
+import numpy as np
+from PIL import Image
+import inspect, re
+import numpy as np
+import os
+import collections
+from PIL import Image
+
+# Converts a Tensor into a Numpy array
+# |imtype|: the desired type of the converted numpy array
+def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
+    if isinstance(image_tensor, list):
+        image_numpy = []
+        for i in range(len(image_tensor)):
+            image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
+        return image_numpy
+    image_numpy = image_tensor.cpu().float().numpy()
+    if normalize:
+        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+    else:
+        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
+    image_numpy = np.clip(image_numpy, 0, 255)
+    if image_numpy.shape[2] == 1:
+        image_numpy = image_numpy[:,:,0]
+    return image_numpy.astype(imtype)
+
+def tensor2label(output, n_label, imtype=np.uint8):
+    output = output.cpu().float()
+    if output.size()[0] > 1:
+        output = output.max(0, keepdim=True)[1]
+    output = Colorize(n_label)(output)
+    output = np.transpose(output.numpy(), (1, 2, 0))
+    return output.astype(imtype)
+
+def save_image(image_numpy, image_path):
+    image_pil = Image.fromarray(image_numpy)
+    image_pil.save(image_path)
+
+def mkdirs(paths):
+    if isinstance(paths, list) and not isinstance(paths, str):
+        for path in paths:
+            mkdir(path)
+    else:
+        mkdir(paths)
+
+def mkdir(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+def uint82bin(n, count=8):
+    """returns the binary representation of integer n; count is the number of bits"""
+    return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])
+
+def labelcolormap(N):
+    if N == 35: # cityscapes
+        cmap = np.array([(  0,  0,  0), (  0,  0,  0), (  0,  0,  0), (  0,  0,  0), (  0,  0,  0), (111, 74,  0), ( 81,  0, 81),
+                         (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
+                         (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220,  0),
+                         (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255,  0,  0), (  0,  0,142), (  0,  0, 70),
+                         (  0, 60,100), (  0,  0, 90), (  0,  0,110), (  0, 80,100), (  0,  0,230), (119, 11, 32), (  0,  0,142)],
+                        dtype=np.uint8)
+    else:
+        cmap = np.zeros((N, 3), dtype=np.uint8)
+        for i in range(N):
+            r = 0
+            g = 0
+            b = 0
+            id = i
+            for j in range(7):
+                str_id = uint82bin(id)
+                r = r ^ (np.uint8(str_id[-1]) << (7 - j))
+                g = g ^ (np.uint8(str_id[-2]) << (7 - j))
+                b = b ^ (np.uint8(str_id[-3]) << (7 - j))
+                id = id >> 3
+            cmap[i, 0] = r
+            cmap[i, 1] = g
+            cmap[i, 2] = b
+    return cmap
+
+class Colorize(object):
+    def __init__(self, n=35):
+        self.cmap = labelcolormap(n)
+        self.cmap = torch.from_numpy(self.cmap[:n])
+
+    def __call__(self, gray_image):
+        size = gray_image.size()
+        color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
+
+        for label in range(0, len(self.cmap)):
+            mask = (label == gray_image[0]).cpu()
+            color_image[0][mask] = self.cmap[label][0]
+            color_image[1][mask] = self.cmap[label][1]
+            color_image[2][mask] = self.cmap[label][2]
+
+        return color_image
diff --git a/util/visualizer.py b/util/visualizer.py
new file mode 100755
index 0000000..f41c55a
--- /dev/null
+++ b/util/visualizer.py
@@ -0,0 +1,133 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import numpy as np
+import os
+import ntpath
+import time
+from . import util
+from . import html
+import scipy.misc
+try:
+    from StringIO import StringIO  # Python 2.7
+except ImportError:
+    from io import BytesIO         # Python 3.x
+
+class Visualizer():
+    def __init__(self, opt):
+        # self.opt = opt
+        self.tf_log = opt.tf_log
+        self.use_html = opt.isTrain and not opt.no_html
+        self.win_size = opt.display_winsize
+        self.name = opt.name
+        if self.tf_log:
+            import tensorflow as tf
+            self.tf = tf
+            self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
+            self.writer = tf.summary.FileWriter(self.log_dir)
+
+        if self.use_html:
+            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+            self.img_dir = os.path.join(self.web_dir, 'images')
+            print('create web directory %s...' % self.web_dir)
+            util.mkdirs([self.web_dir, self.img_dir])
+        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+        with open(self.log_name, "a") as log_file:
+            now = time.strftime("%c")
+            log_file.write('================ Training Loss (%s) ================\n' % now)
+
+    # |visuals|: dictionary of images to display or save
+    def display_current_results(self, visuals, epoch, step):
+        if self.tf_log: # show images in tensorboard output
+            img_summaries = []
+            for label, image_numpy in visuals.items():
+                # Write the image to a string
+                try:
+                    s = StringIO()
+                except:
+                    s = BytesIO()
+                scipy.misc.toimage(image_numpy).save(s, format="jpeg")
+                # Create an Image object
+                img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])
+                # Create a Summary value
+                img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))
+
+            # Create and write Summary
+            summary = self.tf.Summary(value=img_summaries)
+            self.writer.add_summary(summary, step)
+
+        if self.use_html: # save images to an html file
+            for label, image_numpy in visuals.items():
+                if isinstance(image_numpy, list):
+                    for i in range(len(image_numpy)):
+                        img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i))
+                        util.save_image(image_numpy[i], img_path)
+                else:
+                    img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label))
+                    util.save_image(image_numpy, img_path)
+
+            # update website
+            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
+            for n in range(epoch, 0, -1):
+                webpage.add_header('epoch [%d]' % n)
+                ims = []
+                txts = []
+                links = []
+
+                for label, image_numpy in visuals.items():
+                    if isinstance(image_numpy, list):
+                        for i in range(len(image_numpy)):
+                            img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i)
+                            ims.append(img_path)
+                            txts.append(label + str(i))
+                            links.append(img_path)
+                    else:
+                        img_path = 'epoch%.3d_%s.jpg' % (n, label)
+                        ims.append(img_path)
+                        txts.append(label)
+                        links.append(img_path)
+                if len(ims) < 10:
+                    webpage.add_images(ims, txts, links, width=self.win_size)
+                else:
+                    num = int(round(len(ims) / 2.0))
+                    webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
+                    webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
+            webpage.save()
+
+    # errors: dictionary of error labels and values
+    def plot_current_errors(self, errors, step):
+        if self.tf_log:
+            for tag, value in errors.items():
+                summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
+                self.writer.add_summary(summary, step)
+
+    # errors: same format as |errors| of plot_current_errors
+    def print_current_errors(self, epoch, i, errors, t):
+        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
+        for k, v in errors.items():
+            if v != 0:
+                message += '%s: %.3f ' % (k, v)
+
+        print(message)
+        with open(self.log_name, "a") as log_file:
+            log_file.write('%s\n' % message)
+
+    # save image to the disk
+    def save_images(self, webpage, visuals, image_path):
+        image_dir = webpage.get_image_dir()
+        short_path = ntpath.basename(image_path[0])
+        name = os.path.splitext(short_path)[0]
+
+        webpage.add_header(name)
+        ims = []
+        txts = []
+        links = []
+
+        for label, image_numpy in visuals.items():
+            image_name = '%s_%s.jpg' % (name, label)
+            save_path = os.path.join(image_dir, image_name)
+            util.save_image(image_numpy, save_path)
+
+            ims.append(image_name)
+            txts.append(label)
+            links.append(image_name)
+        webpage.add_images(ims, txts, links, width=self.win_size)
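Finally, note that ImagePool in util/image_pool.py is effectively disabled in these scripts: --pool_size defaults to 0, in which case query() returns its input untouched (and initialize() refuses pool_size > 0 on multiple GPUs). A minimal sketch of what query() does once a pool is enabled, with hypothetical sizes:

import torch
from torch.autograd import Variable
from util.image_pool import ImagePool

pool = ImagePool(pool_size=2)
for step in range(4):
    fake = Variable(torch.randn(1, 3, 4, 4))  # hypothetical generated batch
    out = pool.query(fake)                    # fills the pool first, then returns either
    print(step, out.size())                   # the new fake or a stored one (50/50)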
