 .gitignore | 40 +
 LICENSE.txt | 43 +
 README.md | 71 +-
 _config.yml | 1 +
 data/__init__.py | 0
 data/aligned_dataset.py | 76 +
 data/base_data_loader.py | 14 +
 data/base_dataset.py | 92 +
 data/custom_dataset_data_loader.py | 31 +
 data/data_loader.py | 7 +
 data/image_folder.py | 68 +
 datasets/cityscapes/test_inst/frankfurt_000000_000576_gtFine_instanceIds.png | Bin 0 -> 19695 bytes
 datasets/cityscapes/test_inst/frankfurt_000000_001236_gtFine_instanceIds.png | Bin 0 -> 23700 bytes
 datasets/cityscapes/test_inst/frankfurt_000000_003357_gtFine_instanceIds.png | Bin 0 -> 21144 bytes
 datasets/cityscapes/test_inst/frankfurt_000000_011810_gtFine_instanceIds.png | Bin 0 -> 29329 bytes
 datasets/cityscapes/test_inst/frankfurt_000000_012868_gtFine_instanceIds.png | Bin 0 -> 28880 bytes
 datasets/cityscapes/test_inst/frankfurt_000001_013710_gtFine_instanceIds.png | Bin 0 -> 26359 bytes
 datasets/cityscapes/test_inst/frankfurt_000001_015328_gtFine_instanceIds.png | Bin 0 -> 22778 bytes
 datasets/cityscapes/test_inst/frankfurt_000001_023769_gtFine_instanceIds.png | Bin 0 -> 28386 bytes
 datasets/cityscapes/test_inst/frankfurt_000001_028335_gtFine_instanceIds.png | Bin 0 -> 29047 bytes
 datasets/cityscapes/test_inst/frankfurt_000001_032711_gtFine_instanceIds.png | Bin 0 -> 19537 bytes
 datasets/cityscapes/test_inst/frankfurt_000001_033655_gtFine_instanceIds.png | Bin 0 -> 20295 bytes
 datasets/cityscapes/test_inst/frankfurt_000001_042733_gtFine_instanceIds.png | Bin 0 -> 25782 bytes
 datasets/cityscapes/test_inst/frankfurt_000001_047552_gtFine_instanceIds.png | Bin 0 -> 22371 bytes
 datasets/cityscapes/test_inst/frankfurt_000001_054640_gtFine_instanceIds.png | Bin 0 -> 30933 bytes
 datasets/cityscapes/test_inst/frankfurt_000001_055387_gtFine_instanceIds.png | Bin 0 -> 29306 bytes
 datasets/cityscapes/test_label/frankfurt_000000_000576_gtFine_labelIds.png | Bin 0 -> 14775 bytes
 datasets/cityscapes/test_label/frankfurt_000000_001236_gtFine_labelIds.png | Bin 0 -> 18756 bytes
 datasets/cityscapes/test_label/frankfurt_000000_003357_gtFine_labelIds.png | Bin 0 -> 16732 bytes
 datasets/cityscapes/test_label/frankfurt_000000_011810_gtFine_labelIds.png | Bin 0 -> 23273 bytes
 datasets/cityscapes/test_label/frankfurt_000000_012868_gtFine_labelIds.png | Bin 0 -> 23056 bytes
 datasets/cityscapes/test_label/frankfurt_000001_013710_gtFine_labelIds.png | Bin 0 -> 21643 bytes
 datasets/cityscapes/test_label/frankfurt_000001_015328_gtFine_labelIds.png | Bin 0 -> 15809 bytes
 datasets/cityscapes/test_label/frankfurt_000001_023769_gtFine_labelIds.png | Bin 0 -> 22669 bytes
 datasets/cityscapes/test_label/frankfurt_000001_028335_gtFine_labelIds.png | Bin 0 -> 23054 bytes
 datasets/cityscapes/test_label/frankfurt_000001_032711_gtFine_labelIds.png | Bin 0 -> 13258 bytes
 datasets/cityscapes/test_label/frankfurt_000001_033655_gtFine_labelIds.png | Bin 0 -> 14370 bytes
 datasets/cityscapes/test_label/frankfurt_000001_042733_gtFine_labelIds.png | Bin 0 -> 21005 bytes
 datasets/cityscapes/test_label/frankfurt_000001_047552_gtFine_labelIds.png | Bin 0 -> 16416 bytes
 datasets/cityscapes/test_label/frankfurt_000001_054640_gtFine_labelIds.png | Bin 0 -> 23085 bytes
 datasets/cityscapes/test_label/frankfurt_000001_055387_gtFine_labelIds.png | Bin 0 -> 20753 bytes
 datasets/cityscapes/train_img/aachen_000000_000019_leftImg8bit.png | Bin 0 -> 2134771 bytes
 datasets/cityscapes/train_img/aachen_000001_000019_leftImg8bit.png | Bin 0 -> 2138159 bytes
 datasets/cityscapes/train_img/aachen_000002_000019_leftImg8bit.png | Bin 0 -> 2167234 bytes
 datasets/cityscapes/train_img/aachen_000003_000019_leftImg8bit.png | Bin 0 -> 2216193 bytes
 datasets/cityscapes/train_img/aachen_000004_000019_leftImg8bit.png | Bin 0 -> 2176021 bytes
 datasets/cityscapes/train_inst/aachen_000000_000019_gtFine_instanceIds.png | Bin 0 -> 23929 bytes
 datasets/cityscapes/train_inst/aachen_000001_000019_gtFine_instanceIds.png | Bin 0 -> 26610 bytes
 datasets/cityscapes/train_inst/aachen_000002_000019_gtFine_instanceIds.png | Bin 0 -> 18654 bytes
 datasets/cityscapes/train_inst/aachen_000003_000019_gtFine_instanceIds.png | Bin 0 -> 15081 bytes
 datasets/cityscapes/train_inst/aachen_000004_000019_gtFine_instanceIds.png | Bin 0 -> 23201 bytes
 datasets/cityscapes/train_label/aachen_000000_000019_gtFine_labelIds.png | Bin 0 -> 19710 bytes
 datasets/cityscapes/train_label/aachen_000001_000019_gtFine_labelIds.png | Bin 0 -> 22659 bytes
 datasets/cityscapes/train_label/aachen_000002_000019_gtFine_labelIds.png | Bin 0 -> 14956 bytes
 datasets/cityscapes/train_label/aachen_000003_000019_gtFine_labelIds.png | Bin 0 -> 11370 bytes
 datasets/cityscapes/train_label/aachen_000004_000019_gtFine_labelIds.png | Bin 0 -> 18472 bytes
 encode_features.py | 57 +
 models/__init__.py | 0
 models/base_model.py | 86 +
 models/models.py | 14 +
 models/networks.py | 421 +
 models/pix2pixHD_model.py | 260 +
 options/__init__.py | 0
 options/base_options.py | 95 +
 options/test_options.py | 15 +
 options/train_options.py | 36 +
 precompute_feature_maps.py | 36 +
 scripts/test_1024p.sh | 3 +
 scripts/test_1024p_feat.sh | 5 +
 scripts/test_512p.sh | 3 +
 scripts/test_512p_feat.sh | 5 +
 scripts/train_1024p_12G.sh | 4 +
 scripts/train_1024p_24G.sh | 4 +
 scripts/train_1024p_feat_12G.sh | 6 +
 scripts/train_1024p_feat_24G.sh | 6 +
 scripts/train_512p.sh | 2 +
 scripts/train_512p_feat.sh | 2 +
 scripts/train_512p_multigpu.sh | 2 +
 test.py | 37 +
 train.py | 118 +
 util/__init__.py | 0
 util/html.py | 63 +
 util/image_pool.py | 32 +
 util/util.py | 99 +
 util/visualizer.py | 133 +
 85 files changed, 1982 insertions(+), 5 deletions(-)
diff --git a/.gitignore b/.gitignore
new file mode 100755
index 0000000..681efd0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,40 @@
+debug*
+checkpoints/
+results/
+build/
+dist/
+torch.egg-info/
+*/**/__pycache__
+torch/version.py
+torch/csrc/generic/TensorMethods.cpp
+torch/lib/*.so*
+torch/lib/*.dylib*
+torch/lib/*.h
+torch/lib/build
+torch/lib/tmp_install
+torch/lib/include
+torch/lib/torch_shm_manager
+torch/csrc/cudnn/cuDNN.cpp
+torch/csrc/nn/THNN.cwrap
+torch/csrc/nn/THNN.cpp
+torch/csrc/nn/THCUNN.cwrap
+torch/csrc/nn/THCUNN.cpp
+torch/csrc/nn/THNN_generic.cwrap
+torch/csrc/nn/THNN_generic.cpp
+torch/csrc/nn/THNN_generic.h
+docs/src/**/*
+test/data/legacy_modules.t7
+test/data/gpu_tensors.pt
+test/htmlcov
+test/.coverage
+*/*.pyc
+*/**/*.pyc
+*/**/**/*.pyc
+*/**/**/**/*.pyc
+*/**/**/**/**/*.pyc
+*/*.so*
+*/**/*.so*
+*/**/*.dylib*
+test/data/legacy_serialized.pt
+*.DS_Store
+*~
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100755
index 0000000..7406cd7
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,43 @@
+Copyright (C) 2017 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
+All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+
+Permission to use, copy, modify, and distribute this software and its documentation
+for any non-commercial purpose is hereby granted without fee, provided that the above
+copyright notice appear in all copies and that both that copyright notice and this
+permission notice appear in supporting documentation, and that the name of the author
+not be used in advertising or publicity pertaining to distribution of the software
+without specific, written prior permission.
+
+THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
+IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
+DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
+WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
+OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+
+--------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix ----------------
+Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
index fc13ade..c560c46 100755
--- a/README.md
+++ b/README.md
@@ -10,10 +10,6 @@ Pytorch implementation of our method for high-resolution (e.g. 2048x1024) photor
<sup>1</sup>NVIDIA Corporation, <sup>2</sup>UC Berkeley
In arxiv, 2017.
-## Release notice
-The code is ready to publish but still under final approval process. It should be approved in a couple of days.<br>
-If you want to get notified once the code is released, please subscribe [here](https://tcwang0509.github.io/pix2pixHD/subscribe.html).
-
## Image-to-image translation at 2k/1k resolution
- Our label-to-streetview results
<p align='center'>
@@ -53,7 +49,69 @@ If you want to get notified once the code is released, please subscribe [here](h
<img src='imgs/face_short.gif' width='490'/>
</p>
-### Citation
+## Prerequisites
+- Linux or macOS
+- Python 2 or 3
+- NVIDIA GPU (12G or 24G of memory) + CUDA and cuDNN
+
+## Getting Started
+### Installation
+- Install PyTorch and dependencies from http://pytorch.org
+- Install the Python library [dominate](https://github.com/Knio/dominate):
+```bash
+pip install dominate
+```
+- Clone this repo:
+```bash
+git clone https://github.com/NVIDIA/pix2pixHD
+cd pix2pixHD
+```
+
+
+### Testing
+- A few example Cityscapes test images are included in the `datasets` folder.
+- Please download the pre-trained Cityscapes model from [here](https://drive.google.com/file/d/1h9SykUnuZul7J3Nbms2QGH1wa85nbN2-/view?usp=sharing) (google drive link), and put it under `./checkpoints/label2city_1024p/`
+- Test the model (`bash ./scripts/test_1024p.sh`):
+```bash
+#!./scripts/test_1024p.sh
+python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none
+```
+The test results will be saved to an HTML file: `./results/label2city_1024p/test_latest/index.html`.
+
+More example scripts can be found in the `scripts` directory.
+
+
+### Dataset
+- We use the Cityscapes dataset. To train a model on the full dataset, please download it from the [official website](https://www.cityscapes-dataset.com/) (registration required).
+After downloading, please place it under the `datasets` folder, organized in the same way as the provided example images (see the layout below).
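+For reference, the sample data included in this repo is organized as:
+```
+datasets/cityscapes/
+  train_img/      # training photos (*_leftImg8bit.png)
+  train_label/    # semantic label maps (*_gtFine_labelIds.png)
+  train_inst/     # instance maps (*_gtFine_instanceIds.png)
+  test_label/     # test label maps
+  test_inst/      # test instance maps
+```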
+
+
+### Training
+- Train a model at 1024 x 512 resolution (`bash ./scripts/train_512p.sh`):
+```bash
+#!./scripts/train_512p.sh
+python train.py --name label2city_512p
+```
+- To view training results, please check out the intermediate results in `./checkpoints/label2city_512p/web/index.html`.
+If you have TensorFlow installed, you can view TensorBoard logs in `./checkpoints/label2city_512p/logs` by adding `--tf_log` to the training scripts, as shown below.
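+For example (an illustrative invocation; `--tf_log` is the flag mentioned above):
+```bash
+python train.py --name label2city_512p --tf_log
+```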
+
+### Multi-GPU training
+- Train a model using multiple GPUs (`bash ./scripts/train_512p_multigpu.sh`):
+```bash
+#!./scripts/train_512p_multigpu.sh
+python train.py --name label2city_512p --batchSize 8 --gpu_ids 0,1,2,3,4,5,6,7
+```
+Note: this is not tested, and we trained our model using a single GPU only. Please use it at your own discretion.
+
+### Training at full resolution
+- Training at full resolution (2048 x 1024) requires a GPU with 24G of memory (`bash ./scripts/train_1024p_24G.sh`).
+If only GPUs with 12G of memory are available, please use the 12G script (`bash ./scripts/train_1024p_12G.sh`), which will crop the images during training. Performance is not guaranteed with this script.
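+A minimal sketch of such a training command (illustrative flags only, echoing the 1024p test command above; see the actual scripts for the exact options):
+```bash
+python train.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none
+```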
+
+## More Training/Test Details
+- Flags: see `options/train_options.py` and `options/base_options.py` for all the training flags; see `options/test_options.py` and `options/base_options.py` for all the test flags.
+
+
+## Citation
If you find this useful for your research, please use the following.
@@ -65,3 +123,6 @@ If you find this useful for your research, please use the following.
year={2017}
}
```
+
+## Acknowledgments
+This code borrows heavily from [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
diff --git a/_config.yml b/_config.yml
new file mode 100755
index 0000000..2f7efbe
--- /dev/null
+++ b/_config.yml
@@ -0,0 +1 @@
+theme: jekyll-theme-minimal
\ No newline at end of file
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100755
index 0000000..e69de29
--- /dev/null
+++ b/data/__init__.py
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
new file mode 100755
index 0000000..50390f3
--- /dev/null
+++ b/data/aligned_dataset.py
@@ -0,0 +1,76 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import os.path
+import random
+import torchvision.transforms as transforms
+import torch
+from data.base_dataset import BaseDataset, get_params, get_transform, normalize
+from data.image_folder import make_dataset
+from PIL import Image
+import numpy as np
+
+class AlignedDataset(BaseDataset):
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+
+ ### label maps
+ self.dir_label = os.path.join(opt.dataroot, opt.phase + '_label')
+ self.label_paths = sorted(make_dataset(self.dir_label))
+
+ ### real images
+ if opt.isTrain:
+ self.dir_image = os.path.join(opt.dataroot, opt.phase + '_img')
+ self.image_paths = sorted(make_dataset(self.dir_image))
+
+ ### instance maps
+ if not opt.no_instance:
+ self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
+ self.inst_paths = sorted(make_dataset(self.dir_inst))
+
+ ### load precomputed instance-wise encoded features
+ if opt.load_features:
+ self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
+ print('----------- loading features from %s ----------' % self.dir_feat)
+ self.feat_paths = sorted(make_dataset(self.dir_feat))
+
+ self.dataset_size = len(self.label_paths)
+
+ def __getitem__(self, index):
+ ### label maps
+ label_path = self.label_paths[index]
+ label = Image.open(label_path)
+ params = get_params(self.opt, label.size)
+ transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
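+        # ToTensor scales pixel values to [0, 1], so multiplying by 255 below recovers the integer label ids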
+ label_tensor = transform_label(label) * 255.0
+
+ image_tensor = inst_tensor = feat_tensor = 0
+ ### real images
+ if self.opt.isTrain:
+ image_path = self.image_paths[index]
+ image = Image.open(image_path).convert('RGB')
+ transform_image = get_transform(self.opt, params)
+ image_tensor = transform_image(image)
+
+ ### if using instance maps
+ if not self.opt.no_instance:
+ inst_path = self.inst_paths[index]
+ inst = Image.open(inst_path)
+ inst_tensor = transform_label(inst)
+
+ if self.opt.load_features:
+ feat_path = self.feat_paths[index]
+ feat = Image.open(feat_path).convert('RGB')
+ norm = normalize()
+ feat_tensor = norm(transform_label(feat))
+
+ input_dict = {'label': label_tensor, 'inst': inst_tensor, 'image': image_tensor,
+ 'feat': feat_tensor, 'path': label_path}
+
+ return input_dict
+
+ def __len__(self):
+ return len(self.label_paths)
+
+ def name(self):
+        return 'AlignedDataset'
\ No newline at end of file
diff --git a/data/base_data_loader.py b/data/base_data_loader.py
new file mode 100755
index 0000000..0e1deb5
--- /dev/null
+++ b/data/base_data_loader.py
@@ -0,0 +1,14 @@
+
+class BaseDataLoader():
+ def __init__(self):
+ pass
+
+ def initialize(self, opt):
+ self.opt = opt
+ pass
+
+    def load_data(self):
+ return None
+
+
+
diff --git a/data/base_dataset.py b/data/base_dataset.py
new file mode 100755
index 0000000..038d3d2
--- /dev/null
+++ b/data/base_dataset.py
@@ -0,0 +1,92 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms
+import numpy as np
+import random
+
+class BaseDataset(data.Dataset):
+ def __init__(self):
+ super(BaseDataset, self).__init__()
+
+ def name(self):
+ return 'BaseDataset'
+
+ def initialize(self, opt):
+ pass
+
+def get_params(opt, size):
+ w, h = size
+ new_h = h
+ new_w = w
+ if opt.resize_or_crop == 'resize_and_crop':
+ new_h = new_w = opt.loadSize
+ elif opt.resize_or_crop == 'scale_width_and_crop':
+ new_w = opt.loadSize
+ new_h = opt.loadSize * h // w
+
+ x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
+ y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
+
+ flip = random.random() > 0.5
+ return {'crop_pos': (x, y), 'flip': flip}
+
+def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
+ transform_list = []
+ if 'resize' in opt.resize_or_crop:
+ osize = [opt.loadSize, opt.loadSize]
+ transform_list.append(transforms.Scale(osize, method))
+ elif 'scale_width' in opt.resize_or_crop:
+ transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
+
+ if 'crop' in opt.resize_or_crop:
+ transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
+
+ if opt.resize_or_crop == 'none':
+ base = float(2 ** opt.n_downsample_global)
+ if opt.netG == 'local':
+ base *= (2 ** opt.n_local_enhancers)
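+        # round H and W to multiples of base so the generator's repeated 2x down/upsampling preserves the size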
+ transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
+
+ if opt.isTrain and not opt.no_flip:
+ transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
+
+ transform_list += [transforms.ToTensor()]
+
+ if normalize:
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))]
+ return transforms.Compose(transform_list)
+
+def normalize():
+ return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+
+def __make_power_2(img, base, method=Image.BICUBIC):
+ ow, oh = img.size
+ h = int(round(oh / base) * base)
+ w = int(round(ow / base) * base)
+ if (h == oh) and (w == ow):
+ return img
+ return img.resize((w, h), method)
+
+def __scale_width(img, target_width, method=Image.BICUBIC):
+ ow, oh = img.size
+ if (ow == target_width):
+ return img
+ w = target_width
+ h = int(target_width * oh / ow)
+ return img.resize((w, h), method)
+
+def __crop(img, pos, size):
+ ow, oh = img.size
+ x1, y1 = pos
+ tw = th = size
+ if (ow > tw or oh > th):
+ return img.crop((x1, y1, x1 + tw, y1 + th))
+ return img
+
+def __flip(img, flip):
+ if flip:
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
+ return img
diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py
new file mode 100755
index 0000000..0b98254
--- /dev/null
+++ b/data/custom_dataset_data_loader.py
@@ -0,0 +1,31 @@
+import torch.utils.data
+from data.base_data_loader import BaseDataLoader
+
+
+def CreateDataset(opt):
+ dataset = None
+ from data.aligned_dataset import AlignedDataset
+ dataset = AlignedDataset()
+
+ print("dataset [%s] was created" % (dataset.name()))
+ dataset.initialize(opt)
+ return dataset
+
+class CustomDatasetDataLoader(BaseDataLoader):
+ def name(self):
+ return 'CustomDatasetDataLoader'
+
+ def initialize(self, opt):
+ BaseDataLoader.initialize(self, opt)
+ self.dataset = CreateDataset(opt)
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=opt.batchSize,
+ shuffle=not opt.serial_batches,
+ num_workers=int(opt.nThreads))
+
+ def load_data(self):
+ return self.dataloader
+
+ def __len__(self):
+ return min(len(self.dataset), self.opt.max_dataset_size)
diff --git a/data/data_loader.py b/data/data_loader.py
new file mode 100755
index 0000000..2a4433a
--- /dev/null
+++ b/data/data_loader.py
@@ -0,0 +1,7 @@
+
+def CreateDataLoader(opt):
+ from data.custom_dataset_data_loader import CustomDatasetDataLoader
+ data_loader = CustomDatasetDataLoader()
+ print(data_loader.name())
+ data_loader.initialize(opt)
+ return data_loader
diff --git a/data/image_folder.py b/data/image_folder.py
new file mode 100755
index 0000000..16a447c
--- /dev/null
+++ b/data/image_folder.py
@@ -0,0 +1,68 @@
+###############################################################################
+# Code from
+# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
+# Modified the original code so that it also loads images from the current
+# directory as well as the subdirectories
+###############################################################################
+
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir):
+ images = []
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+ for root, _, fnames in sorted(os.walk(dir)):
+ for fname in fnames:
+ if is_image_file(fname):
+ path = os.path.join(root, fname)
+ images.append(path)
+
+ return images
+
+
+def default_loader(path):
+ return Image.open(path).convert('RGB')
+
+
+class ImageFolder(data.Dataset):
+
+ def __init__(self, root, transform=None, return_paths=False,
+ loader=default_loader):
+ imgs = make_dataset(root)
+ if len(imgs) == 0:
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
+ "Supported image extensions are: " +
+ ",".join(IMG_EXTENSIONS)))
+
+ self.root = root
+ self.imgs = imgs
+ self.transform = transform
+ self.return_paths = return_paths
+ self.loader = loader
+
+ def __getitem__(self, index):
+ path = self.imgs[index]
+ img = self.loader(path)
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.return_paths:
+ return img, path
+ else:
+ return img
+
+ def __len__(self):
+ return len(self.imgs)
diff --git a/datasets/cityscapes/test_inst/frankfurt_000000_000576_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000000_000576_gtFine_instanceIds.png
new file mode 100755
index 0000000..01da7ed
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000000_000576_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_inst/frankfurt_000000_001236_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000000_001236_gtFine_instanceIds.png
new file mode 100755
index 0000000..75506bc
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000000_001236_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_inst/frankfurt_000000_003357_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000000_003357_gtFine_instanceIds.png
new file mode 100755
index 0000000..9bd27b0
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000000_003357_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_inst/frankfurt_000000_011810_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000000_011810_gtFine_instanceIds.png
new file mode 100755
index 0000000..df84eee
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000000_011810_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_inst/frankfurt_000000_012868_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000000_012868_gtFine_instanceIds.png
new file mode 100755
index 0000000..ba1f7aa
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000000_012868_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_inst/frankfurt_000001_013710_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000001_013710_gtFine_instanceIds.png
new file mode 100755
index 0000000..d05b7db
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000001_013710_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_inst/frankfurt_000001_015328_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000001_015328_gtFine_instanceIds.png
new file mode 100755
index 0000000..32d62a3
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000001_015328_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_inst/frankfurt_000001_023769_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000001_023769_gtFine_instanceIds.png
new file mode 100755
index 0000000..9eef682
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000001_023769_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_inst/frankfurt_000001_028335_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000001_028335_gtFine_instanceIds.png
new file mode 100755
index 0000000..b1909d5
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000001_028335_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_inst/frankfurt_000001_032711_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000001_032711_gtFine_instanceIds.png
new file mode 100755
index 0000000..ac2e293
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000001_032711_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_inst/frankfurt_000001_033655_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000001_033655_gtFine_instanceIds.png
new file mode 100755
index 0000000..de7328e
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000001_033655_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_inst/frankfurt_000001_042733_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000001_042733_gtFine_instanceIds.png
new file mode 100755
index 0000000..a98d096
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000001_042733_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_inst/frankfurt_000001_047552_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000001_047552_gtFine_instanceIds.png
new file mode 100755
index 0000000..ab569e3
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000001_047552_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_inst/frankfurt_000001_054640_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000001_054640_gtFine_instanceIds.png
new file mode 100755
index 0000000..5f246a6
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000001_054640_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_inst/frankfurt_000001_055387_gtFine_instanceIds.png b/datasets/cityscapes/test_inst/frankfurt_000001_055387_gtFine_instanceIds.png
new file mode 100755
index 0000000..2e7d01f
--- /dev/null
+++ b/datasets/cityscapes/test_inst/frankfurt_000001_055387_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000000_000576_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000000_000576_gtFine_labelIds.png
new file mode 100755
index 0000000..8c9464c
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000000_000576_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000000_001236_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000000_001236_gtFine_labelIds.png
new file mode 100755
index 0000000..9f0ca9f
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000000_001236_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000000_003357_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000000_003357_gtFine_labelIds.png
new file mode 100755
index 0000000..1035e55
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000000_003357_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000000_011810_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000000_011810_gtFine_labelIds.png
new file mode 100755
index 0000000..a86913b
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000000_011810_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000000_012868_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000000_012868_gtFine_labelIds.png
new file mode 100755
index 0000000..fe81c83
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000000_012868_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000001_013710_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000001_013710_gtFine_labelIds.png
new file mode 100755
index 0000000..72b4be4
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000001_013710_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000001_015328_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000001_015328_gtFine_labelIds.png
new file mode 100755
index 0000000..afefb6b
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000001_015328_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000001_023769_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000001_023769_gtFine_labelIds.png
new file mode 100755
index 0000000..f3af9df
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000001_023769_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000001_028335_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000001_028335_gtFine_labelIds.png
new file mode 100755
index 0000000..5e65e3e
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000001_028335_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000001_032711_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000001_032711_gtFine_labelIds.png
new file mode 100755
index 0000000..ba07b73
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000001_032711_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000001_033655_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000001_033655_gtFine_labelIds.png
new file mode 100755
index 0000000..77f519c
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000001_033655_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000001_042733_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000001_042733_gtFine_labelIds.png
new file mode 100755
index 0000000..ba08f1d
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000001_042733_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000001_047552_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000001_047552_gtFine_labelIds.png
new file mode 100755
index 0000000..5dff09a
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000001_047552_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000001_054640_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000001_054640_gtFine_labelIds.png
new file mode 100755
index 0000000..cb2ab2b
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000001_054640_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/test_label/frankfurt_000001_055387_gtFine_labelIds.png b/datasets/cityscapes/test_label/frankfurt_000001_055387_gtFine_labelIds.png
new file mode 100755
index 0000000..b00ef7e
--- /dev/null
+++ b/datasets/cityscapes/test_label/frankfurt_000001_055387_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/train_img/aachen_000000_000019_leftImg8bit.png b/datasets/cityscapes/train_img/aachen_000000_000019_leftImg8bit.png
new file mode 100755
index 0000000..0e6867e
--- /dev/null
+++ b/datasets/cityscapes/train_img/aachen_000000_000019_leftImg8bit.png
Binary files differ
diff --git a/datasets/cityscapes/train_img/aachen_000001_000019_leftImg8bit.png b/datasets/cityscapes/train_img/aachen_000001_000019_leftImg8bit.png
new file mode 100755
index 0000000..d5a96ce
--- /dev/null
+++ b/datasets/cityscapes/train_img/aachen_000001_000019_leftImg8bit.png
Binary files differ
diff --git a/datasets/cityscapes/train_img/aachen_000002_000019_leftImg8bit.png b/datasets/cityscapes/train_img/aachen_000002_000019_leftImg8bit.png
new file mode 100755
index 0000000..10ce563
--- /dev/null
+++ b/datasets/cityscapes/train_img/aachen_000002_000019_leftImg8bit.png
Binary files differ
diff --git a/datasets/cityscapes/train_img/aachen_000003_000019_leftImg8bit.png b/datasets/cityscapes/train_img/aachen_000003_000019_leftImg8bit.png
new file mode 100755
index 0000000..3027fe1
--- /dev/null
+++ b/datasets/cityscapes/train_img/aachen_000003_000019_leftImg8bit.png
Binary files differ
diff --git a/datasets/cityscapes/train_img/aachen_000004_000019_leftImg8bit.png b/datasets/cityscapes/train_img/aachen_000004_000019_leftImg8bit.png
new file mode 100755
index 0000000..26945fc
--- /dev/null
+++ b/datasets/cityscapes/train_img/aachen_000004_000019_leftImg8bit.png
Binary files differ
diff --git a/datasets/cityscapes/train_inst/aachen_000000_000019_gtFine_instanceIds.png b/datasets/cityscapes/train_inst/aachen_000000_000019_gtFine_instanceIds.png
new file mode 100755
index 0000000..f4ee222
--- /dev/null
+++ b/datasets/cityscapes/train_inst/aachen_000000_000019_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/train_inst/aachen_000001_000019_gtFine_instanceIds.png b/datasets/cityscapes/train_inst/aachen_000001_000019_gtFine_instanceIds.png
new file mode 100755
index 0000000..dd69137
--- /dev/null
+++ b/datasets/cityscapes/train_inst/aachen_000001_000019_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/train_inst/aachen_000002_000019_gtFine_instanceIds.png b/datasets/cityscapes/train_inst/aachen_000002_000019_gtFine_instanceIds.png
new file mode 100755
index 0000000..bdad5e3
--- /dev/null
+++ b/datasets/cityscapes/train_inst/aachen_000002_000019_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/train_inst/aachen_000003_000019_gtFine_instanceIds.png b/datasets/cityscapes/train_inst/aachen_000003_000019_gtFine_instanceIds.png
new file mode 100755
index 0000000..91a035b
--- /dev/null
+++ b/datasets/cityscapes/train_inst/aachen_000003_000019_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/train_inst/aachen_000004_000019_gtFine_instanceIds.png b/datasets/cityscapes/train_inst/aachen_000004_000019_gtFine_instanceIds.png
new file mode 100755
index 0000000..0f5fc70
--- /dev/null
+++ b/datasets/cityscapes/train_inst/aachen_000004_000019_gtFine_instanceIds.png
Binary files differ
diff --git a/datasets/cityscapes/train_label/aachen_000000_000019_gtFine_labelIds.png b/datasets/cityscapes/train_label/aachen_000000_000019_gtFine_labelIds.png
new file mode 100755
index 0000000..eed7ee6
--- /dev/null
+++ b/datasets/cityscapes/train_label/aachen_000000_000019_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/train_label/aachen_000001_000019_gtFine_labelIds.png b/datasets/cityscapes/train_label/aachen_000001_000019_gtFine_labelIds.png
new file mode 100755
index 0000000..e9c25ee
--- /dev/null
+++ b/datasets/cityscapes/train_label/aachen_000001_000019_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/train_label/aachen_000002_000019_gtFine_labelIds.png b/datasets/cityscapes/train_label/aachen_000002_000019_gtFine_labelIds.png
new file mode 100755
index 0000000..c96ab17
--- /dev/null
+++ b/datasets/cityscapes/train_label/aachen_000002_000019_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/train_label/aachen_000003_000019_gtFine_labelIds.png b/datasets/cityscapes/train_label/aachen_000003_000019_gtFine_labelIds.png
new file mode 100755
index 0000000..da05594
--- /dev/null
+++ b/datasets/cityscapes/train_label/aachen_000003_000019_gtFine_labelIds.png
Binary files differ
diff --git a/datasets/cityscapes/train_label/aachen_000004_000019_gtFine_labelIds.png b/datasets/cityscapes/train_label/aachen_000004_000019_gtFine_labelIds.png
new file mode 100755
index 0000000..bb30bd9
--- /dev/null
+++ b/datasets/cityscapes/train_label/aachen_000004_000019_gtFine_labelIds.png
Binary files differ
diff --git a/encode_features.py b/encode_features.py
new file mode 100755
index 0000000..0e97da8
--- /dev/null
+++ b/encode_features.py
@@ -0,0 +1,57 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+from options.train_options import TrainOptions
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+import numpy as np
+import os, time
+import util.util as util
+from torch.autograd import Variable
+
+opt = TrainOptions().parse()
+opt.nThreads = 1
+opt.batchSize = 1
+opt.serial_batches = True
+opt.no_flip = True
+opt.instance_feat = True
+
+name = 'features'
+save_path = os.path.join(opt.checkpoints_dir, opt.name)
+
+############ Initialize #########
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+dataset_size = len(data_loader)
+model = create_model(opt)
+
+########### Encode features ###########
+reencode = True
+if reencode:
+ features = {}
+ for label in range(opt.label_nc):
+ features[label] = np.zeros((0, opt.feat_num+1))
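+        # each row stores feat_num feature values plus one indicator column (filtered with > 0.5 before clustering)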
+ for i, data in enumerate(dataset):
+ feat = model.module.encode_features(data['image'], data['inst'])
+ for label in range(opt.label_nc):
+ features[label] = np.append(features[label], feat[label], axis=0)
+
+ print('%d / %d images' % (i+1, dataset_size))
+ save_name = os.path.join(save_path, name + '.npy')
+ np.save(save_name, features)
+
+############## Clustering ###########
+n_clusters = opt.n_clusters
+load_name = os.path.join(save_path, name + '.npy')
+features = np.load(load_name).item()
+from sklearn.cluster import KMeans
+centers = {}
+for label in range(opt.label_nc):
+ feat = features[label]
+ feat = feat[feat[:,-1] > 0.5, :-1]
+ if feat.shape[0]:
+ n_clusters = min(feat.shape[0], opt.n_clusters)
+ kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(feat)
+ centers[label] = kmeans.cluster_centers_
+save_name = os.path.join(save_path, name + '_clustered_%03d.npy' % opt.n_clusters)
+np.save(save_name, centers)
+print('saving to %s' % save_name)
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100755
index 0000000..e69de29
--- /dev/null
+++ b/models/__init__.py
diff --git a/models/base_model.py b/models/base_model.py
new file mode 100755
index 0000000..d3879d0
--- /dev/null
+++ b/models/base_model.py
@@ -0,0 +1,86 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import os
+import torch
+
+class BaseModel(torch.nn.Module):
+ def name(self):
+ return 'BaseModel'
+
+ def initialize(self, opt):
+ self.opt = opt
+ self.gpu_ids = opt.gpu_ids
+ self.isTrain = opt.isTrain
+ self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
+
+ def set_input(self, input):
+ self.input = input
+
+ def forward(self):
+ pass
+
+ # used in test time, no backprop
+ def test(self):
+ pass
+
+ def get_image_paths(self):
+ pass
+
+ def optimize_parameters(self):
+ pass
+
+ def get_current_visuals(self):
+ return self.input
+
+ def get_current_errors(self):
+ return {}
+
+ def save(self, label):
+ pass
+
+ # helper saving function that can be used by subclasses
+ def save_network(self, network, network_label, epoch_label, gpu_ids):
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
+ save_path = os.path.join(self.save_dir, save_filename)
+ torch.save(network.cpu().state_dict(), save_path)
+ if len(gpu_ids) and torch.cuda.is_available():
+ network.cuda()
+
+ # helper loading function that can be used by subclasses
+ def load_network(self, network, network_label, epoch_label, save_dir=''):
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
+ if not save_dir:
+ save_dir = self.save_dir
+ save_path = os.path.join(save_dir, save_filename)
+ if not os.path.isfile(save_path):
+            print('%s does not exist yet!' % save_path)
+            if network_label == 'G':
+                raise Exception('Generator must exist!')
+ else:
+ #network.load_state_dict(torch.load(save_path))
+ try:
+ network.load_state_dict(torch.load(save_path))
+ except:
+ pretrained_dict = torch.load(save_path)
+ model_dict = network.state_dict()
+ try:
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+ network.load_state_dict(pretrained_dict)
+                print('Pretrained network %s has excess layers; only loading the layers that are used' % network_label)
+            except:
+                print('Pretrained network %s has fewer layers; the following are not initialized:' % network_label)
+                not_initialized = set()
+ for k, v in pretrained_dict.items():
+ if v.size() == model_dict[k].size():
+ model_dict[k] = v
+
+ for k, v in model_dict.items():
+ if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
+ not_initialized.add(k.split('.')[0])
+ print(sorted(not_initialized))
+ network.load_state_dict(model_dict)
+
+    def update_learning_rate(self):
+ pass
diff --git a/models/models.py b/models/models.py
new file mode 100755
index 0000000..351483c
--- /dev/null
+++ b/models/models.py
@@ -0,0 +1,14 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import torch
+
+def create_model(opt):
+ from .pix2pixHD_model import Pix2PixHDModel
+ model = Pix2PixHDModel()
+ model.initialize(opt)
+ print("model [%s] was created" % (model.name()))
+
+ if opt.isTrain and len(opt.gpu_ids):
+ model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
+
+ return model
diff --git a/models/networks.py b/models/networks.py
new file mode 100755
index 0000000..a673a56
--- /dev/null
+++ b/models/networks.py
@@ -0,0 +1,421 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import torch
+import torch.nn as nn
+from torch.nn import init
+import functools
+from torch.autograd import Variable
+import numpy as np
+import math
+import torch.nn.functional as F
+import copy
+
+###############################################################################
+# Functions
+###############################################################################
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ m.weight.data.normal_(0.0, 0.02)
+ elif classname.find('BatchNorm2d') != -1:
+ m.weight.data.normal_(1.0, 0.02)
+ m.bias.data.fill_(0)
+
+def get_norm_layer(norm_type='instance'):
+ if norm_type == 'batch':
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
+ elif norm_type == 'instance':
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
+ else:
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+ return norm_layer
+
+def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
+ n_blocks_local=3, norm='instance', gpu_ids=[]):
+ norm_layer = get_norm_layer(norm_type=norm)
+ if netG == 'global':
+ netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
+ elif netG == 'local':
+ netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
+ n_local_enhancers, n_blocks_local, norm_layer)
+ elif netG == 'encoder':
+ netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)
+ else:
+        raise NotImplementedError('generator not implemented!')
+ print(netG)
+ if len(gpu_ids) > 0:
+ assert(torch.cuda.is_available())
+ netG.cuda(device_id=gpu_ids[0])
+ netG.apply(weights_init)
+ return netG
+
+def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
+ norm_layer = get_norm_layer(norm_type=norm)
+ netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
+ print(netD)
+ if len(gpu_ids) > 0:
+ assert(torch.cuda.is_available())
+ netD.cuda(device_id=gpu_ids[0])
+ netD.apply(weights_init)
+ return netD
+
+def print_network(net):
+ if isinstance(net, list):
+ net = net[0]
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ print(net)
+ print('Total number of parameters: %d' % num_params)
+
+##############################################################################
+# Losses
+##############################################################################
+class GANLoss(nn.Module):
+ def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
+ tensor=torch.FloatTensor):
+ super(GANLoss, self).__init__()
+ self.real_label = target_real_label
+ self.fake_label = target_fake_label
+ self.real_label_var = None
+ self.fake_label_var = None
+ self.Tensor = tensor
+ if use_lsgan:
+ self.loss = nn.MSELoss()
+ else:
+ self.loss = nn.BCELoss()
+
+ def get_target_tensor(self, input, target_is_real):
+ target_tensor = None
+ if target_is_real:
+ create_label = ((self.real_label_var is None) or
+ (self.real_label_var.numel() != input.numel()))
+ if create_label:
+ real_tensor = self.Tensor(input.size()).fill_(self.real_label)
+ self.real_label_var = Variable(real_tensor, requires_grad=False)
+ target_tensor = self.real_label_var
+ else:
+ create_label = ((self.fake_label_var is None) or
+ (self.fake_label_var.numel() != input.numel()))
+ if create_label:
+ fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
+ self.fake_label_var = Variable(fake_tensor, requires_grad=False)
+ target_tensor = self.fake_label_var
+ return target_tensor
+
+ def __call__(self, input, target_is_real):
+ if isinstance(input[0], list):
+ loss = 0
+ for input_i in input:
+ pred = input_i[-1]
+ target_tensor = self.get_target_tensor(pred, target_is_real)
+ loss += self.loss(pred, target_tensor)
+ return loss
+ else:
+ target_tensor = self.get_target_tensor(input[-1], target_is_real)
+ return self.loss(input[-1], target_tensor)
+
+class VGGLoss(nn.Module):
+ def __init__(self, gpu_ids):
+ super(VGGLoss, self).__init__()
+ self.vgg = Vgg19().cuda()
+ self.criterion = nn.L1Loss()
+ self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
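+        # deeper VGG feature maps receive larger weights in the L1 feature-matching sum below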
+
+ def forward(self, x, y):
+ x_vgg, y_vgg = self.vgg(x), self.vgg(y)
+ loss = 0
+ for i in range(len(x_vgg)):
+ loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
+ return loss
+
+##############################################################################
+# Generator
+##############################################################################
+class LocalEnhancer(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9,
+ n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'):
+ super(LocalEnhancer, self).__init__()
+ self.n_local_enhancers = n_local_enhancers
+
+ ###### global generator model #####
+ ngf_global = ngf * (2**n_local_enhancers)
+ model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model
+ model_global = [model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers
+ self.model = nn.Sequential(*model_global)
+
+ ###### local enhancer layers #####
+ for n in range(1, n_local_enhancers+1):
+ ### downsample
+ ngf_global = ngf * (2**(n_local_enhancers-n))
+ model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
+ norm_layer(ngf_global), nn.ReLU(True),
+ nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1),
+ norm_layer(ngf_global * 2), nn.ReLU(True)]
+ ### residual blocks
+ model_upsample = []
+ for i in range(n_blocks_local):
+ model_upsample += [ResnetBlock(ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer)]
+
+ ### upsample
+ model_upsample += [nn.ConvTranspose2d(ngf_global * 2, ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1),
+ norm_layer(ngf_global), nn.ReLU(True)]
+
+ ### final convolution
+ if n == n_local_enhancers:
+ model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
+
+ setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample))
+ setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample))
+
+ self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
+
+ def forward(self, input):
+ ### create input pyramid
+ input_downsampled = [input]
+ for i in range(self.n_local_enhancers):
+ input_downsampled.append(self.downsample(input_downsampled[-1]))
+
+        ### output at coarsest level
+ output_prev = self.model(input_downsampled[-1])
+ ### build up one layer at a time
+ for n_local_enhancers in range(1, self.n_local_enhancers+1):
+ model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1')
+ model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2')
+ input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers]
+ output_prev = model_upsample(model_downsample(input_i) + output_prev)
+ return output_prev
+
+class GlobalGenerator(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
+ padding_type='reflect'):
+ assert(n_blocks >= 0)
+ super(GlobalGenerator, self).__init__()
+ activation = nn.ReLU(True)
+
+ model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
+ ### downsample
+ for i in range(n_downsampling):
+ mult = 2**i
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
+ norm_layer(ngf * mult * 2), activation]
+
+ ### resnet blocks
+ mult = 2**n_downsampling
+ for i in range(n_blocks):
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
+
+ ### upsample
+ for i in range(n_downsampling):
+ mult = 2**(n_downsampling - i)
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
+ norm_layer(int(ngf * mult / 2)), activation]
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input):
+ return self.model(input)
+
+# Define a resnet block
+class ResnetBlock(nn.Module):
+ def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
+ super(ResnetBlock, self).__init__()
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
+
+ def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
+ conv_block = []
+ p = 0
+ if padding_type == 'reflect':
+ conv_block += [nn.ReflectionPad2d(1)]
+ elif padding_type == 'replicate':
+ conv_block += [nn.ReplicationPad2d(1)]
+ elif padding_type == 'zero':
+ p = 1
+ else:
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+ norm_layer(dim),
+ activation]
+ if use_dropout:
+ conv_block += [nn.Dropout(0.5)]
+
+ p = 0
+ if padding_type == 'reflect':
+ conv_block += [nn.ReflectionPad2d(1)]
+ elif padding_type == 'replicate':
+ conv_block += [nn.ReplicationPad2d(1)]
+ elif padding_type == 'zero':
+ p = 1
+ else:
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+ norm_layer(dim)]
+
+ return nn.Sequential(*conv_block)
+
+ def forward(self, x):
+ out = x + self.conv_block(x)
+ return out
+
+class Encoder(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
+ super(Encoder, self).__init__()
+ self.output_nc = output_nc
+
+ model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
+ norm_layer(ngf), nn.ReLU(True)]
+ ### downsample
+ for i in range(n_downsampling):
+ mult = 2**i
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
+ norm_layer(ngf * mult * 2), nn.ReLU(True)]
+
+ ### upsample
+ for i in range(n_downsampling):
+ mult = 2**(n_downsampling - i)
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
+ norm_layer(int(ngf * mult / 2)), nn.ReLU(True)]
+
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input, inst):
+ outputs = self.model(input)
+
+ # instance-wise average pooling
+ outputs_mean = outputs.clone()
+ inst_list = np.unique(inst.cpu().numpy().astype(int))
+ for i in inst_list:
+ indices = (inst == i).nonzero() # n x 4
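+            # average each feature channel over all pixels belonging to instance i, then broadcast the mean back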
+ for j in range(self.output_nc):
+ output_ins = outputs[indices[:,0], indices[:,1] + j, indices[:,2], indices[:,3]]
+ mean_feat = torch.mean(output_ins).expand_as(output_ins)
+ outputs_mean[indices[:,0], indices[:,1] + j, indices[:,2], indices[:,3]] = mean_feat
+ return outputs_mean
+
+class MultiscaleDiscriminator(nn.Module):
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
+ use_sigmoid=False, num_D=3, getIntermFeat=False):
+ super(MultiscaleDiscriminator, self).__init__()
+ self.num_D = num_D
+ self.n_layers = n_layers
+ self.getIntermFeat = getIntermFeat
+
+ for i in range(num_D):
+ netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
+ if getIntermFeat:
+ for j in range(n_layers+2):
+ setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
+ else:
+ setattr(self, 'layer'+str(i), netD.model)
+
+ self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
+
+ def singleD_forward(self, model, input):
+ if self.getIntermFeat:
+ result = [input]
+ for i in range(len(model)):
+ result.append(model[i](result[-1]))
+ return result[1:]
+ else:
+ return [model(input)]
+
+ def forward(self, input):
+ num_D = self.num_D
+ result = []
+ input_downsampled = input
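+        # run each discriminator on a progressively downsampled copy of the input (one scale per discriminator)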
+ for i in range(num_D):
+ if self.getIntermFeat:
+ model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
+ else:
+ model = getattr(self, 'layer'+str(num_D-1-i))
+ result.append(self.singleD_forward(model, input_downsampled))
+ if i != (num_D-1):
+ input_downsampled = self.downsample(input_downsampled)
+ return result
+
+# Defines the PatchGAN discriminator with the specified arguments.
+class NLayerDiscriminator(nn.Module):
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
+ super(NLayerDiscriminator, self).__init__()
+ self.getIntermFeat = getIntermFeat
+ self.n_layers = n_layers
+
+ kw = 4
+ padw = int(np.ceil((kw-1.0)/2))
+ sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
+
+ nf = ndf
+ for n in range(1, n_layers):
+ nf_prev = nf
+ nf = min(nf * 2, 512)
+ sequence += [[
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
+ norm_layer(nf), nn.LeakyReLU(0.2, True)
+ ]]
+
+ nf_prev = nf
+ nf = min(nf * 2, 512)
+ sequence += [[
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
+ norm_layer(nf),
+ nn.LeakyReLU(0.2, True)
+ ]]
+
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
+
+ if use_sigmoid:
+ sequence += [[nn.Sigmoid()]]
+
+ if getIntermFeat:
+ for n in range(len(sequence)):
+ setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
+ else:
+ sequence_stream = []
+ for n in range(len(sequence)):
+ sequence_stream += sequence[n]
+ self.model = nn.Sequential(*sequence_stream)
+
+ def forward(self, input):
+ if self.getIntermFeat:
+ res = [input]
+ for n in range(self.n_layers+2):
+ model = getattr(self, 'model'+str(n))
+ res.append(model(res[-1]))
+ return res[1:]
+ else:
+ return self.model(input)
+
+from torchvision import models
+class Vgg19(torch.nn.Module):
+ def __init__(self, requires_grad=False):
+ super(Vgg19, self).__init__()
+ vgg_pretrained_features = models.vgg19(pretrained=True).features
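+ # the five slices end at relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1;
+ # their activations are the features compared by the VGG perceptual loss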
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ for x in range(2):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(2, 7):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(7, 12):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(12, 21):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(21, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h_relu1 = self.slice1(X)
+ h_relu2 = self.slice2(h_relu1)
+ h_relu3 = self.slice3(h_relu2)
+ h_relu4 = self.slice4(h_relu3)
+ h_relu5 = self.slice5(h_relu4)
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
+ return out
diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py
new file mode 100755
index 0000000..ba44e53
--- /dev/null
+++ b/models/pix2pixHD_model.py
@@ -0,0 +1,260 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import numpy as np
+import torch
+import os
+from collections import OrderedDict
+from torch.autograd import Variable
+import util.util as util
+from util.image_pool import ImagePool
+from .base_model import BaseModel
+from . import networks
+
+class Pix2PixHDModel(BaseModel):
+ def name(self):
+ return 'Pix2PixHDModel'
+
+ def initialize(self, opt):
+ BaseModel.initialize(self, opt)
+ if opt.resize_or_crop != 'none': # cudnn benchmark causes OOM when training at full resolution, so enable it only when resizing/cropping
+ torch.backends.cudnn.benchmark = True
+ self.isTrain = opt.isTrain
+ self.use_features = opt.instance_feat or opt.label_feat
+ self.gen_features = self.use_features and not self.opt.load_features
+
+ ##### define networks
+ # Generator network
+ netG_input_nc = opt.label_nc
+ if not opt.no_instance:
+ netG_input_nc += 1
+ if self.use_features:
+ netG_input_nc += opt.feat_num
+ self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
+ opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
+ opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
+
+ # Discriminator network
+ if self.isTrain:
+ use_sigmoid = opt.no_lsgan
+ netD_input_nc = opt.label_nc + opt.output_nc
+ if not opt.no_instance:
+ netD_input_nc += 1
+ self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
+ opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
+
+ ### Encoder network
+ if self.gen_features:
+ self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
+ opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
+
+ print('---------- Networks initialized -------------')
+
+ # load networks
+ if not self.isTrain or opt.continue_train or opt.load_pretrain:
+ pretrained_path = '' if not self.isTrain else opt.load_pretrain
+ self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
+ if self.isTrain:
+ self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
+ if self.gen_features:
+ self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)
+
+ # set loss functions and optimizers
+ if self.isTrain:
+ if opt.pool_size > 0 and len(self.gpu_ids) > 1:
+ raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
+ self.fake_pool = ImagePool(opt.pool_size)
+ self.old_lr = opt.lr
+
+ # define loss functions
+ self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
+ self.criterionFeat = torch.nn.L1Loss()
+ if not opt.no_vgg_loss:
+ self.criterionVGG = networks.VGGLoss(self.gpu_ids)
+
+ # names so we can break out each loss term
+ self.loss_names = ['G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake']
+
+ # initialize optimizers
+ # optimizer G
+ if opt.niter_fix_global > 0:
+ print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
+ params_dict = dict(self.netG.named_parameters())
+ params = []
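+ # parameters whose names start with 'model<n_local_enhancers>' belong to the
+ # outermost local enhancer; every other parameter gets lr 0, which freezes
+ # the global generator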
+ for key, value in params_dict.items():
+ if key.startswith('model' + str(opt.n_local_enhancers)):
+ params += [{'params':[value],'lr':opt.lr}]
+ else:
+ params += [{'params':[value],'lr':0.0}]
+ else:
+ params = list(self.netG.parameters())
+ if self.gen_features:
+ params += list(self.netE.parameters())
+ self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
+
+ # optimizer D
+ params = list(self.netD.parameters())
+ self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
+
+ def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
+ # create one-hot vector for label map
+ size = label_map.size()
+ oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
+ input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
+ input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
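+ # scatter_ writes 1.0 along the channel axis at each pixel's label id,
+ # turning the integer label map into a label_nc-channel one-hot tensor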
+
+ # get edges from instance map
+ if not self.opt.no_instance:
+ inst_map = inst_map.data.cuda()
+ edge_map = self.get_edges(inst_map)
+ input_label = torch.cat((input_label, edge_map), dim=1)
+ input_label = Variable(input_label, volatile=infer)
+
+ # real images for training
+ if real_image is not None:
+ real_image = Variable(real_image.data.cuda())
+
+ # instance map for feature encoding
+ if self.use_features:
+ # get precomputed feature maps
+ if self.opt.load_features:
+ feat_map = Variable(feat_map.data.cuda())
+
+ return input_label, inst_map, real_image, feat_map
+
+ def discriminate(self, input_label, test_image, use_pool=False):
+ input_concat = torch.cat((input_label, test_image.detach()), dim=1)
+ if use_pool:
+ fake_query = self.fake_pool.query(input_concat)
+ return self.netD.forward(fake_query)
+ else:
+ return self.netD.forward(input_concat)
+
+ def forward(self, label, inst, image, feat, infer=False):
+ # Encode Inputs
+ input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)
+
+ # Fake Generation
+ if self.use_features:
+ if not self.opt.load_features:
+ feat_map = self.netE.forward(real_image, inst_map)
+ input_concat = torch.cat((input_label, feat_map), dim=1)
+ else:
+ input_concat = input_label
+ fake_image = self.netG.forward(input_concat)
+
+ # Fake Detection and Loss
+ pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
+ loss_D_fake = self.criterionGAN(pred_fake_pool, False)
+
+ # Real Detection and Loss
+ pred_real = self.discriminate(input_label, real_image)
+ loss_D_real = self.criterionGAN(pred_real, True)
+
+ # GAN loss (Fake Passability Loss)
+ pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
+ loss_G_GAN = self.criterionGAN(pred_fake, True)
+
+ # GAN feature matching loss
+ loss_G_GAN_Feat = 0
+ if not self.opt.no_ganFeat_loss:
+ feat_weights = 4.0 / (self.opt.n_layers_D + 1)
+ D_weights = 1.0 / self.opt.num_D
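+ # D_weights averages over the num_D discriminators and feat_weights over the
+ # discriminator layers, so lambda_feat keeps a consistent scale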
+ for i in range(self.opt.num_D):
+ for j in range(len(pred_fake[i])-1):
+ loss_G_GAN_Feat += D_weights * feat_weights * \
+ self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
+
+ # VGG feature matching loss
+ loss_G_VGG = 0
+ if not self.opt.no_vgg_loss:
+ loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
+
+ # Only return the generated image when necessary, to save bandwidth
+ return [ [ loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ], None if not infer else fake_image ]
+
+ def inference(self, label, inst):
+ # Encode Inputs
+ input_label, inst_map, _, _ = self.encode_input(Variable(label), Variable(inst), infer=True)
+
+ # Fake Generation
+ if self.use_features:
+ # sample clusters from precomputed features
+ feat_map = self.sample_features(inst_map)
+ input_concat = torch.cat((input_label, feat_map), dim=1)
+ else:
+ input_concat = input_label
+ fake_image = self.netG.forward(input_concat)
+ return fake_image
+
+ def sample_features(self, inst):
+ # read precomputed feature clusters
+ cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
+ features_clustered = np.load(cluster_path).item()
+
+ # randomly sample from the feature clusters
+ inst_np = inst.cpu().numpy().astype(int)
+ feat_map = torch.cuda.FloatTensor(1, self.opt.feat_num, inst.size()[2], inst.size()[3])
+ for i in np.unique(inst_np):
+ label = i if i < 1000 else i//1000
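+ # Cityscapes encodes instance ids as class_id * 1000 + index, so integer
+ # division recovers the semantic class of instance i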
+ if label in features_clustered:
+ feat = features_clustered[label]
+ cluster_idx = np.random.randint(0, feat.shape[0])
+
+ idx = (inst == i).nonzero()
+ for k in range(self.opt.feat_num):
+ feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
+ return feat_map
+
+ def encode_features(self, image, inst):
+ image = Variable(image.cuda(), volatile=True)
+ feat_num = self.opt.feat_num
+ h, w = inst.size()[2], inst.size()[3]
+ block_num = 32
+ feat_map = self.netE.forward(image, inst.cuda())
+ inst_np = inst.cpu().numpy().astype(int)
+ feature = {}
+ for i in range(self.opt.label_nc):
+ feature[i] = np.zeros((0, feat_num+1))
+ for i in np.unique(inst_np):
+ label = i if i < 1000 else i//1000
+ idx = (inst == i).nonzero()
+ num = idx.size()[0]
+ idx = idx[num//2,:]
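+ # sample one representative pixel (the middle entry of the index list);
+ # features are constant within an instance after the encoder's pooling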
+ val = np.zeros((1, feat_num+1))
+ for k in range(feat_num):
+ val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
+ val[0, feat_num] = float(num) / (h * w // block_num)
+ feature[label] = np.append(feature[label], val, axis=0)
+ return feature
+
+ def get_edges(self, t):
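+ # a pixel is marked as an edge if its instance id differs from its
+ # left/right neighbor (first two lines) or its top/bottom neighbor (last two)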
+ edge = torch.cuda.ByteTensor(t.size()).zero_()
+ edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
+ edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
+ edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
+ edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
+ return edge.float()
+
+ def save(self, which_epoch):
+ self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
+ self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
+ if self.gen_features:
+ self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)
+
+ def update_fixed_params(self):
+ # after fixing the global generator for a number of iterations, also start finetuning it
+ params = list(self.netG.parameters())
+ if self.gen_features:
+ params += list(self.netE.parameters())
+ self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
+ print('------------ Now also finetuning global generator -----------')
+
+ def update_learning_rate(self):
+ lrd = self.opt.lr / self.opt.niter_decay
+ lr = self.old_lr - lrd
+ for param_group in self.optimizer_D.param_groups:
+ param_group['lr'] = lr
+ for param_group in self.optimizer_G.param_groups:
+ param_group['lr'] = lr
+ print('update learning rate: %f -> %f' % (self.old_lr, lr))
+ self.old_lr = lr
diff --git a/options/__init__.py b/options/__init__.py
new file mode 100755
index 0000000..e69de29
--- /dev/null
+++ b/options/__init__.py
diff --git a/options/base_options.py b/options/base_options.py
new file mode 100755
index 0000000..863c061
--- /dev/null
+++ b/options/base_options.py
@@ -0,0 +1,95 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import argparse
+import os
+from util import util
+import torch
+
+class BaseOptions():
+ def __init__(self):
+ self.parser = argparse.ArgumentParser()
+ self.initialized = False
+
+ def initialize(self):
+ # experiment specifics
+ self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models')
+ self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
+ self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
+ self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
+ self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
+
+ # input/output sizes
+ self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
+ self.parser.add_argument('--loadSize', type=int, default=1024, help='scale images to this size')
+ self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
+ self.parser.add_argument('--label_nc', type=int, default=35, help='# of input label channels')
+ self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
+
+ # for setting inputs
+ self.parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/')
+ self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none]')
+ self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+ self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
+ self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
+ self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+
+ # for displays
+ self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
+ self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
+
+ # for generator
+ self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
+ self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
+ self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG')
+ self.parser.add_argument('--n_blocks_global', type=int, default=9, help='number of residual blocks in the global generator network')
+ self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network')
+ self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use')
+ self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs during which only the outermost local enhancer is trained')
+
+ # for instance-wise features
+ self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
+ self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input')
+ self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input')
+ self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features')
+ self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps')
+ self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder')
+ self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
+ self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features')
+
+ self.initialized = True
+
+ def parse(self, save=True):
+ if not self.initialized:
+ self.initialize()
+ self.opt = self.parser.parse_args()
+ self.opt.isTrain = self.isTrain # train or test
+
+ str_ids = self.opt.gpu_ids.split(',')
+ self.opt.gpu_ids = []
+ for str_id in str_ids:
+ id = int(str_id)
+ if id >= 0:
+ self.opt.gpu_ids.append(id)
+
+ # set gpu ids
+ if len(self.opt.gpu_ids) > 0:
+ torch.cuda.set_device(self.opt.gpu_ids[0])
+
+ args = vars(self.opt)
+
+ print('------------ Options -------------')
+ for k, v in sorted(args.items()):
+ print('%s: %s' % (str(k), str(v)))
+ print('-------------- End ----------------')
+
+ # save to the disk
+ expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
+ util.mkdirs(expr_dir)
+ if save and not self.opt.continue_train:
+ file_name = os.path.join(expr_dir, 'opt.txt')
+ with open(file_name, 'wt') as opt_file:
+ opt_file.write('------------ Options -------------\n')
+ for k, v in sorted(args.items()):
+ opt_file.write('%s: %s\n' % (str(k), str(v)))
+ opt_file.write('-------------- End ----------------\n')
+ return self.opt
diff --git a/options/test_options.py b/options/test_options.py
new file mode 100755
index 0000000..aaeff53
--- /dev/null
+++ b/options/test_options.py
@@ -0,0 +1,15 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+from .base_options import BaseOptions
+
+class TestOptions(BaseOptions):
+ def initialize(self):
+ BaseOptions.initialize(self)
+ self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
+ self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
+ self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
+ self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+ self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+ self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
+ self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
+ self.isTrain = False
diff --git a/options/train_options.py b/options/train_options.py
new file mode 100755
index 0000000..9994a4a
--- /dev/null
+++ b/options/train_options.py
@@ -0,0 +1,36 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+from .base_options import BaseOptions
+
+class TrainOptions(BaseOptions):
+ def initialize(self):
+ BaseOptions.initialize(self)
+ # for displays
+ self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
+ self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
+ self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
+ self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
+ self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
+ self.parser.add_argument('--debug', action='store_true', help='only run one epoch and display results at every iteration')
+
+ # for training
+ self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
+ self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
+ self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+ self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
+ self.parser.add_argument('--niter', type=int, default=100, help='# of epochs at the initial learning rate')
+ self.parser.add_argument('--niter_decay', type=int, default=100, help='# of epochs over which the learning rate linearly decays to zero')
+ self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
+ self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
+
+ # for discriminators
+ self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
+ self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
+ self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
+ self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
+ self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
+ self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
+ self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least squares GAN; if specified, use vanilla GAN')
+ self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
+
+ self.isTrain = True
diff --git a/precompute_feature_maps.py b/precompute_feature_maps.py
new file mode 100755
index 0000000..a631b9c
--- /dev/null
+++ b/precompute_feature_maps.py
@@ -0,0 +1,36 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+from options.train_options import TrainOptions
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+import numpy as np
+import os, time
+import util.util as util
+from torch.autograd import Variable
+import torch.nn as nn
+
+opt = TrainOptions().parse()
+opt.nThreads = 1
+opt.batchSize = 1
+opt.serial_batches = True
+opt.no_flip = True
+opt.instance_feat = True
+
+name = 'features'
+save_path = os.path.join(opt.checkpoints_dir, opt.name)
+
+############ Initialize #########
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+dataset_size = len(data_loader)
+model = create_model(opt)
+util.mkdirs(os.path.join(opt.dataroot, opt.phase + '_feat'))
+
+######## Save precomputed feature maps for 1024p training #######
+for i, data in enumerate(dataset):
+ print('%d / %d images' % (i+1, dataset_size))
+ feat_map = model.module.netE.forward(Variable(data['image'].cuda(), volatile=True), data['inst'].cuda())
+ feat_map = nn.Upsample(scale_factor=2, mode='nearest')(feat_map)
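+ # upsample 2x so feature maps computed at the 512p resolution match the
+ # 1024p training inputs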
+ image_numpy = util.tensor2im(feat_map.data[0])
+ save_path = data['path'][0].replace('/train_label/', '/train_feat/')
+ util.save_image(image_numpy, save_path) \ No newline at end of file
diff --git a/scripts/test_1024p.sh b/scripts/test_1024p.sh
new file mode 100755
index 0000000..99c1e24
--- /dev/null
+++ b/scripts/test_1024p.sh
@@ -0,0 +1,3 @@
+################################ Testing ################################
+# labels only
+python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none \ No newline at end of file
diff --git a/scripts/test_1024p_feat.sh b/scripts/test_1024p_feat.sh
new file mode 100755
index 0000000..2f4ba17
--- /dev/null
+++ b/scripts/test_1024p_feat.sh
@@ -0,0 +1,5 @@
+################################ Testing ################################
+# first precompute and cluster all features
+python encode_features.py --name label2city_1024p_feat --netG local --ngf 32 --resize_or_crop none;
+# use instance-wise features
+python test.py --name label2city_1024p_feat --netG local --ngf 32 --resize_or_crop none --instance_feat \ No newline at end of file
diff --git a/scripts/test_512p.sh b/scripts/test_512p.sh
new file mode 100755
index 0000000..3131043
--- /dev/null
+++ b/scripts/test_512p.sh
@@ -0,0 +1,3 @@
+################################ Testing ################################
+# labels only
+python test.py --name label2city_512p \ No newline at end of file
diff --git a/scripts/test_512p_feat.sh b/scripts/test_512p_feat.sh
new file mode 100755
index 0000000..8f25e9c
--- /dev/null
+++ b/scripts/test_512p_feat.sh
@@ -0,0 +1,5 @@
+################################ Testing ################################
+# first precompute and cluster all features
+python encode_features.py --name label2city_512p_feat;
+# use instance-wise features
+python test.py --name label2city_512p_feat --instance_feat \ No newline at end of file
diff --git a/scripts/train_1024p_12G.sh b/scripts/train_1024p_12G.sh
new file mode 100755
index 0000000..d5ea7d7
--- /dev/null
+++ b/scripts/train_1024p_12G.sh
@@ -0,0 +1,4 @@
+############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
+##### Using GPUs with 12G memory (not tested)
+# Using labels only
+python train.py --name label2city_1024p --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p/ --niter_fix_global 20 --resize_or_crop crop --fineSize 1024 \ No newline at end of file
diff --git a/scripts/train_1024p_24G.sh b/scripts/train_1024p_24G.sh
new file mode 100755
index 0000000..88e58f7
--- /dev/null
+++ b/scripts/train_1024p_24G.sh
@@ -0,0 +1,4 @@
+############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
+######## Using GPUs with 24G memory
+# Using labels only
+python train.py --name label2city_1024p --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p/ --niter 50 --niter_decay 50 --niter_fix_global 10 --resize_or_crop none \ No newline at end of file
diff --git a/scripts/train_1024p_feat_12G.sh b/scripts/train_1024p_feat_12G.sh
new file mode 100755
index 0000000..f8e3d61
--- /dev/null
+++ b/scripts/train_1024p_feat_12G.sh
@@ -0,0 +1,6 @@
+############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
+##### Using GPUs with 12G memory (not tested)
+# First precompute feature maps and save them
+python precompute_feature_maps.py --name label2city_512p_feat;
+# Adding instances and encoded features
+python train.py --name label2city_1024p_feat --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p_feat/ --niter_fix_global 20 --resize_or_crop crop --fineSize 896 --instance_feat --load_features \ No newline at end of file
diff --git a/scripts/train_1024p_feat_24G.sh b/scripts/train_1024p_feat_24G.sh
new file mode 100755
index 0000000..399d720
--- /dev/null
+++ b/scripts/train_1024p_feat_24G.sh
@@ -0,0 +1,6 @@
+############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
+######## Using GPUs with 24G memory
+# First precompute feature maps and save them
+python precompute_feature_maps.py --name label2city_512p_feat;
+# Adding instances and encoded features
+python train.py --name label2city_1024p_feat --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p_feat/ --niter 50 --niter_decay 50 --niter_fix_global 10 --resize_or_crop none --instance_feat --load_features \ No newline at end of file
diff --git a/scripts/train_512p.sh b/scripts/train_512p.sh
new file mode 100755
index 0000000..222c348
--- /dev/null
+++ b/scripts/train_512p.sh
@@ -0,0 +1,2 @@
+### Using labels only
+python train.py --name label2city_512p \ No newline at end of file
diff --git a/scripts/train_512p_feat.sh b/scripts/train_512p_feat.sh
new file mode 100755
index 0000000..9d4859c
--- /dev/null
+++ b/scripts/train_512p_feat.sh
@@ -0,0 +1,2 @@
+### Adding instances and encoded features
+python train.py --name label2city_512p_feat --instance_feat \ No newline at end of file
diff --git a/scripts/train_512p_multigpu.sh b/scripts/train_512p_multigpu.sh
new file mode 100755
index 0000000..16f0a1a
--- /dev/null
+++ b/scripts/train_512p_multigpu.sh
@@ -0,0 +1,2 @@
+######## Multi-GPU training example #######
+python train.py --name label2city_512p --batchSize 8 --gpu_ids 0,1,2,3,4,5,6,7 \ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100755
index 0000000..d96ac10
--- /dev/null
+++ b/test.py
@@ -0,0 +1,37 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import time
+import os
+from collections import OrderedDict
+from options.test_options import TestOptions
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+import util.util as util
+from util.visualizer import Visualizer
+from util import html
+
+opt = TestOptions().parse(save=False)
+opt.nThreads = 1 # test code only supports nThreads = 1
+opt.batchSize = 1 # test code only supports batchSize = 1
+opt.serial_batches = True # no shuffle
+opt.no_flip = True # no flip
+
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+model = create_model(opt)
+visualizer = Visualizer(opt)
+# create website
+web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
+webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
+# test
+for i, data in enumerate(dataset):
+ if i >= opt.how_many:
+ break
+ generated = model.inference(data['label'], data['inst'])
+ visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
+ ('synthesized_image', util.tensor2im(generated.data[0]))])
+ img_path = data['path']
+ print('processing image... %s' % img_path)
+ visualizer.save_images(webpage, visuals, img_path)
+
+webpage.save()
diff --git a/train.py b/train.py
new file mode 100755
index 0000000..4965481
--- /dev/null
+++ b/train.py
@@ -0,0 +1,118 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import time
+from collections import OrderedDict
+from options.train_options import TrainOptions
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+import util.util as util
+from util.visualizer import Visualizer
+import os
+import numpy as np
+import torch
+from torch.autograd import Variable
+
+opt = TrainOptions().parse()
+iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
+if opt.continue_train:
+ try:
+ start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
+ except:
+ start_epoch, epoch_iter = 1, 0
+ print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
+else:
+ start_epoch, epoch_iter = 1, 0
+
+if opt.debug:
+ opt.display_freq = 1
+ opt.print_freq = 1
+ opt.niter = 1
+ opt.niter_decay = 0
+ opt.max_dataset_size = 10
+
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+dataset_size = len(data_loader)
+print('#training images = %d' % dataset_size)
+
+model = create_model(opt)
+visualizer = Visualizer(opt)
+
+total_steps = (start_epoch-1) * dataset_size + epoch_iter
+for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
+ epoch_start_time = time.time()
+ if epoch != start_epoch:
+ epoch_iter = epoch_iter % dataset_size
+ for i, data in enumerate(dataset, start=epoch_iter):
+ iter_start_time = time.time()
+ total_steps += opt.batchSize
+ epoch_iter += opt.batchSize
+
+ # whether to collect output images
+ save_fake = total_steps % opt.display_freq == 0
+
+ ############## Forward Pass ######################
+ losses, generated = model(Variable(data['label']), Variable(data['inst']),
+ Variable(data['image']), Variable(data['feat']), infer=save_fake)
+
+ # sum per device losses
+ losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]
+ loss_dict = dict(zip(model.module.loss_names, losses))
+
+ # calculate final loss scalar
+ loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
+ loss_G = loss_dict['G_GAN'] + loss_dict['G_GAN_Feat'] + loss_dict['G_VGG']
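+ # lambda_feat weighting was already applied to the feature-matching and
+ # VGG terms inside the model's forward pass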
+
+ ############### Backward Pass ####################
+ # update generator weights
+ model.module.optimizer_G.zero_grad()
+ loss_G.backward()
+ model.module.optimizer_G.step()
+
+ # update discriminator weights
+ model.module.optimizer_D.zero_grad()
+ loss_D.backward()
+ model.module.optimizer_D.step()
+
+ #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])
+
+ ############## Display results and errors ##########
+ ### print out errors
+ if total_steps % opt.print_freq == 0:
+ errors = {k: v.data[0] if not isinstance(v, (int, float)) else v for k, v in loss_dict.items()}
+ t = (time.time() - iter_start_time) / opt.batchSize
+ visualizer.print_current_errors(epoch, epoch_iter, errors, t)
+ visualizer.plot_current_errors(errors, total_steps)
+
+ ### display output images
+ if save_fake:
+ visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
+ ('synthesized_image', util.tensor2im(generated.data[0])),
+ ('real_image', util.tensor2im(data['image'][0]))])
+ visualizer.display_current_results(visuals, epoch, total_steps)
+
+ ### save latest model
+ if total_steps % opt.save_latest_freq == 0:
+ print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
+ model.module.save('latest')
+ np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
+
+ # end of epoch
+ iter_end_time = time.time()
+ print('End of epoch %d / %d \t Time Taken: %d sec' %
+ (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
+
+ ### save model for this epoch
+ if epoch % opt.save_epoch_freq == 0:
+ print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
+ model.module.save('latest')
+ model.module.save(epoch)
+ np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
+
+ ### instead of only training the local enhancer, train the entire network after opt.niter_fix_global epochs
+ if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
+ model.module.update_fixed_params()
+
+ ### linearly decay the learning rate once past opt.niter epochs
+ if epoch > opt.niter:
+ model.module.update_learning_rate()
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100755
index 0000000..e69de29
--- /dev/null
+++ b/util/__init__.py
diff --git a/util/html.py b/util/html.py
new file mode 100755
index 0000000..a80aa59
--- /dev/null
+++ b/util/html.py
@@ -0,0 +1,63 @@
+import dominate
+from dominate.tags import *
+import os
+
+
+class HTML:
+ def __init__(self, web_dir, title, refresh=0):
+ self.title = title
+ self.web_dir = web_dir
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ if not os.path.exists(self.web_dir):
+ os.makedirs(self.web_dir)
+ if not os.path.exists(self.img_dir):
+ os.makedirs(self.img_dir)
+
+ self.doc = dominate.document(title=title)
+ if refresh > 0:
+ with self.doc.head:
+ meta(http_equiv="refresh", content=str(refresh))
+
+ def get_image_dir(self):
+ return self.img_dir
+
+ def add_header(self, text):
+ with self.doc:
+ h3(text)
+
+ def add_table(self, border=1):
+ self.t = table(border=border, style="table-layout: fixed;")
+ self.doc.add(self.t)
+
+ def add_images(self, ims, txts, links, width=512):
+ self.add_table()
+ with self.t:
+ with tr():
+ for im, txt, link in zip(ims, txts, links):
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
+ with p():
+ with a(href=os.path.join('images', link)):
+ img(style="width:%dpx" % (width), src=os.path.join('images', im))
+ br()
+ p(txt)
+
+ def save(self):
+ html_file = '%s/index.html' % self.web_dir
+ f = open(html_file, 'wt')
+ f.write(self.doc.render())
+ f.close()
+
+
+if __name__ == '__main__':
+ html = HTML('web/', 'test_html')
+ html.add_header('hello world')
+
+ ims = []
+ txts = []
+ links = []
+ for n in range(4):
+ ims.append('image_%d.jpg' % n)
+ txts.append('text_%d' % n)
+ links.append('image_%d.jpg' % n)
+ html.add_images(ims, txts, links)
+ html.save()
diff --git a/util/image_pool.py b/util/image_pool.py
new file mode 100755
index 0000000..152ef5b
--- /dev/null
+++ b/util/image_pool.py
@@ -0,0 +1,32 @@
+import random
+import numpy as np
+import torch
+from torch.autograd import Variable
+class ImagePool():
+ def __init__(self, pool_size):
+ self.pool_size = pool_size
+ if self.pool_size > 0:
+ self.num_imgs = 0
+ self.images = []
+
+ def query(self, images):
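+ # With probability 0.5 return a previously generated image and replace the
+ # stored copy with the new one; otherwise pass the new image through. This
+ # history buffer helps keep the discriminator from overfitting to the
+ # latest fakes.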
+ if self.pool_size == 0:
+ return images
+ return_images = []
+ for image in images.data:
+ image = torch.unsqueeze(image, 0)
+ if self.num_imgs < self.pool_size:
+ self.num_imgs = self.num_imgs + 1
+ self.images.append(image)
+ return_images.append(image)
+ else:
+ p = random.uniform(0, 1)
+ if p > 0.5:
+ random_id = random.randint(0, self.pool_size-1)
+ tmp = self.images[random_id].clone()
+ self.images[random_id] = image
+ return_images.append(tmp)
+ else:
+ return_images.append(image)
+ return_images = Variable(torch.cat(return_images, 0))
+ return return_images
diff --git a/util/util.py b/util/util.py
new file mode 100755
index 0000000..0898f7a
--- /dev/null
+++ b/util/util.py
@@ -0,0 +1,99 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+from __future__ import print_function
+import torch
+import numpy as np
+from PIL import Image
+import inspect, re
+import numpy as np
+import os
+import collections
+from PIL import Image
+
+# Converts a Tensor into a Numpy array
+# |imtype|: the desired type of the converted numpy array
+def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
+ if isinstance(image_tensor, list):
+ image_numpy = []
+ for i in range(len(image_tensor)):
+ image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
+ return image_numpy
+ image_numpy = image_tensor.cpu().float().numpy()
+ if normalize:
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+ else:
+ image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
+ image_numpy = np.clip(image_numpy, 0, 255)
+ if image_numpy.shape[2] == 1:
+ image_numpy = image_numpy[:,:,0]
+ return image_numpy.astype(imtype)
+
+def tensor2label(output, n_label, imtype=np.uint8):
+ output = output.cpu().float()
+ if output.size()[0] > 1:
+ output = output.max(0, keepdim=True)[1]
+ output = Colorize(n_label)(output)
+ output = np.transpose(output.numpy(), (1, 2, 0))
+ return output.astype(imtype)
+
+def save_image(image_numpy, image_path):
+ image_pil = Image.fromarray(image_numpy)
+ image_pil.save(image_path)
+
+def mkdirs(paths):
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+def uint82bin(n, count=8):
+ """returns the binary of integer n, count refers to amount of bits"""
+ return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
+
+def labelcolormap(N):
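+ # use the official Cityscapes palette for 35 classes; otherwise generate a
+ # PASCAL VOC-style colormap by bit manipulation of each label id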
+ if N == 35: # cityscape
+ cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81),
+ (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
+ (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0),
+ (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70),
+ ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)],
+ dtype=np.uint8)
+ else:
+ cmap = np.zeros((N, 3), dtype=np.uint8)
+ for i in range(N):
+ r = 0
+ g = 0
+ b = 0
+ id = i
+ for j in range(7):
+ str_id = uint82bin(id)
+ r = r ^ (np.uint8(str_id[-1]) << (7-j))
+ g = g ^ (np.uint8(str_id[-2]) << (7-j))
+ b = b ^ (np.uint8(str_id[-3]) << (7-j))
+ id = id >> 3
+ cmap[i, 0] = r
+ cmap[i, 1] = g
+ cmap[i, 2] = b
+ return cmap
+
+class Colorize(object):
+ def __init__(self, n=35):
+ self.cmap = labelcolormap(n)
+ self.cmap = torch.from_numpy(self.cmap[:n])
+
+ def __call__(self, gray_image):
+ size = gray_image.size()
+ color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
+
+ for label in range(0, len(self.cmap)):
+ mask = (label == gray_image[0]).cpu()
+ color_image[0][mask] = self.cmap[label][0]
+ color_image[1][mask] = self.cmap[label][1]
+ color_image[2][mask] = self.cmap[label][2]
+
+ return color_image
diff --git a/util/visualizer.py b/util/visualizer.py
new file mode 100755
index 0000000..f41c55a
--- /dev/null
+++ b/util/visualizer.py
@@ -0,0 +1,133 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import numpy as np
+import os
+import ntpath
+import time
+from . import util
+from . import html
+import scipy.misc
+try:
+ from StringIO import StringIO # Python 2.7
+except ImportError:
+ from io import BytesIO # Python 3.x
+
+class Visualizer():
+ def __init__(self, opt):
+ # self.opt = opt
+ self.tf_log = opt.tf_log
+ self.use_html = opt.isTrain and not opt.no_html
+ self.win_size = opt.display_winsize
+ self.name = opt.name
+ if self.tf_log:
+ import tensorflow as tf
+ self.tf = tf
+ self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
+ self.writer = tf.summary.FileWriter(self.log_dir)
+
+ if self.use_html:
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ print('create web directory %s...' % self.web_dir)
+ util.mkdirs([self.web_dir, self.img_dir])
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write('================ Training Loss (%s) ================\n' % now)
+
+ # |visuals|: dictionary of images to display or save
+ def display_current_results(self, visuals, epoch, step):
+ if self.tf_log: # show images in tensorboard output
+ img_summaries = []
+ for label, image_numpy in visuals.items():
+ # Write the image to a string
+ try:
+ s = StringIO()
+ except:
+ s = BytesIO()
+ scipy.misc.toimage(image_numpy).save(s, format="jpeg")
+ # Create an Image object
+ img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])
+ # Create a Summary value
+ img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))
+
+ # Create and write Summary
+ summary = self.tf.Summary(value=img_summaries)
+ self.writer.add_summary(summary, step)
+
+ if self.use_html: # save images to a html file
+ for label, image_numpy in visuals.items():
+ if isinstance(image_numpy, list):
+ for i in range(len(image_numpy)):
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i))
+ util.save_image(image_numpy[i], img_path)
+ else:
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label))
+ util.save_image(image_numpy, img_path)
+
+ # update website
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
+ for n in range(epoch, 0, -1):
+ webpage.add_header('epoch [%d]' % n)
+ ims = []
+ txts = []
+ links = []
+
+ for label, image_numpy in visuals.items():
+ if isinstance(image_numpy, list):
+ for i in range(len(image_numpy)):
+ img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i)
+ ims.append(img_path)
+ txts.append(label+str(i))
+ links.append(img_path)
+ else:
+ img_path = 'epoch%.3d_%s.jpg' % (n, label)
+ ims.append(img_path)
+ txts.append(label)
+ links.append(img_path)
+ if len(ims) < 10:
+ webpage.add_images(ims, txts, links, width=self.win_size)
+ else:
+ num = int(round(len(ims)/2.0))
+ webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
+ webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
+ webpage.save()
+
+ # errors: dictionary of error labels and values
+ def plot_current_errors(self, errors, step):
+ if self.tf_log:
+ for tag, value in errors.items():
+ summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
+ self.writer.add_summary(summary, step)
+
+ # errors: same format as |errors| of plotCurrentErrors
+ def print_current_errors(self, epoch, i, errors, t):
+ message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
+ for k, v in errors.items():
+ if v != 0:
+ message += '%s: %.3f ' % (k, v)
+
+ print(message)
+ with open(self.log_name, "a") as log_file:
+ log_file.write('%s\n' % message)
+
+ # save image to the disk
+ def save_images(self, webpage, visuals, image_path):
+ image_dir = webpage.get_image_dir()
+ short_path = ntpath.basename(image_path[0])
+ name = os.path.splitext(short_path)[0]
+
+ webpage.add_header(name)
+ ims = []
+ txts = []
+ links = []
+
+ for label, image_numpy in visuals.items():
+ image_name = '%s_%s.jpg' % (name, label)
+ save_path = os.path.join(image_dir, image_name)
+ util.save_image(image_numpy, save_path)
+
+ ims.append(image_name)
+ txts.append(label)
+ links.append(image_name)
+ webpage.add_images(ims, txts, links, width=self.win_size)