-rwxr-xr-x  README.md                 2
-rwxr-xr-x  data/aligned_dataset.py  54
-rwxr-xr-x  models/models.py          8
-rwxr-xr-x  models/ui_model.py      349
-rwxr-xr-x  options/base_options.py   1
5 files changed, 385 insertions(+), 29 deletions(-)
diff --git a/README.md b/README.md
index c6ff340..80b425a 100755
--- a/README.md
+++ b/README.md
@@ -109,7 +109,7 @@ If only GPUs with 12G memory are available, please use the 12G script (`bash ./s
### Training with your own dataset
- If you want to train with your own dataset, please generate label maps which are one-channel images whose pixel values correspond to the object labels (i.e. 0, 1, ..., N-1, where N is the number of labels). This is because we need to generate one-hot vectors from the label maps. Please also specify `--label_nc N` during both training and testing.
-- If your input is not a label map, please just specify `--label_nc 0` which will directly use the RGB colors as input.
+- If your input is not a label map, please just specify `--label_nc 0` which will directly use the RGB colors as input. The folders should then be named `train_A`, `train_B` instead of `train_label`, `train_img`, where the goal is to translate images from A to B.
- If you don't have instance maps or don't want to use them, please specify `--no_instance`.
- The default setting for preprocessing is `scale_width`, which will scale the width of all training images to `opt.loadSize` (1024) while keeping the aspect ratio. If you want a different setting, please change it by using the `--resize_or_crop` option. For example, `scale_width_and_crop` first resizes the image to have width `opt.loadSize` and then does random cropping of size `(opt.fineSize, opt.fineSize)`. `crop` skips the resizing step and only performs random cropping. If you don't want any preprocessing, please specify `none`, which will do nothing other than making sure the image is divisible by 32.
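For example, with `--label_nc 0` a dataset root could look like this (a sketch of the convention described above; `datasets/mydata` is a placeholder name):

    datasets/mydata/
        train_A/    # source images (domain A)
        train_B/    # corresponding target images (domain B)
        test_A/     # inputs used at test time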
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index a3cdc76..41468d2 100755
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -10,14 +10,16 @@ class AlignedDataset(BaseDataset):
self.opt = opt
self.root = opt.dataroot
- ### label maps
- self.dir_label = os.path.join(opt.dataroot, opt.phase + '_label')
- self.label_paths = sorted(make_dataset(self.dir_label))
+ ### input A (label maps)
+ dir_A = '_A' if self.opt.label_nc == 0 else '_label'
+ self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
+ self.A_paths = sorted(make_dataset(self.dir_A))
- ### real images
+ ### input B (real images)
if opt.isTrain:
- self.dir_image = os.path.join(opt.dataroot, opt.phase + '_img')
- self.image_paths = sorted(make_dataset(self.dir_image))
+ dir_B = '_B' if self.opt.label_nc == 0 else '_img'
+ self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
+ self.B_paths = sorted(make_dataset(self.dir_B))
### instance maps
if not opt.no_instance:
@@ -30,47 +32,47 @@ class AlignedDataset(BaseDataset):
print('----------- loading features from %s ----------' % self.dir_feat)
self.feat_paths = sorted(make_dataset(self.dir_feat))
- self.dataset_size = len(self.label_paths)
+ self.dataset_size = len(self.A_paths)
def __getitem__(self, index):
- ### label maps
- label_path = self.label_paths[index]
- label = Image.open(label_path)
- params = get_params(self.opt, label.size)
+ ### input A (label maps)
+ A_path = self.A_paths[index]
+ A = Image.open(A_path)
+ params = get_params(self.opt, A.size)
if self.opt.label_nc == 0:
- transform_label = get_transform(self.opt, params)
- label_tensor = transform_label(label.convert('RGB'))
+ transform_A = get_transform(self.opt, params)
+ A_tensor = transform_A(A.convert('RGB'))
else:
- transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
- label_tensor = transform_label(label) * 255.0
+ transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
+ A_tensor = transform_A(A) * 255.0
- image_tensor = inst_tensor = feat_tensor = 0
- ### real images
+ B_tensor = inst_tensor = feat_tensor = 0
+ ### input B (real images)
if self.opt.isTrain:
- image_path = self.image_paths[index]
- image = Image.open(image_path).convert('RGB')
- transform_image = get_transform(self.opt, params)
- image_tensor = transform_image(image)
+ B_path = self.B_paths[index]
+ B = Image.open(B_path).convert('RGB')
+ transform_B = get_transform(self.opt, params)
+ B_tensor = transform_B(B)
### if using instance maps
if not self.opt.no_instance:
inst_path = self.inst_paths[index]
inst = Image.open(inst_path)
- inst_tensor = transform_label(inst)
+ inst_tensor = transform_A(inst)
if self.opt.load_features:
feat_path = self.feat_paths[index]
feat = Image.open(feat_path).convert('RGB')
norm = normalize()
- feat_tensor = norm(transform_label(feat))
+ feat_tensor = norm(transform_A(feat))
- input_dict = {'label': label_tensor, 'inst': inst_tensor, 'image': image_tensor,
- 'feat': feat_tensor, 'path': label_path}
+ input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
+ 'feat': feat_tensor, 'path': A_path}
return input_dict
def __len__(self):
- return len(self.label_paths)
+ return len(self.A_paths)
def name(self):
return 'AlignedDataset'
\ No newline at end of file
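The directory resolution introduced above, pulled out as a standalone sketch (the helper name `resolve_dirs` is mine, for illustration only):

    import os

    def resolve_dirs(dataroot, phase, label_nc, is_train):
        # Mirrors AlignedDataset.initialize: label_nc == 0 means the input
        # is plain RGB, so the folders are <phase>_A / <phase>_B; otherwise
        # the original <phase>_label / <phase>_img names are kept.
        dir_A = os.path.join(dataroot, phase + ('_A' if label_nc == 0 else '_label'))
        dir_B = None
        if is_train:
            dir_B = os.path.join(dataroot, phase + ('_B' if label_nc == 0 else '_img'))
        return dir_A, dir_B

    print(resolve_dirs('./datasets/mydata', 'train', 0, True))
    # ('./datasets/mydata/train_A', './datasets/mydata/train_B')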
diff --git a/models/models.py b/models/models.py
index 351483c..0ba442f 100755
--- a/models/models.py
+++ b/models/models.py
@@ -3,8 +3,12 @@
import torch
def create_model(opt):
- from .pix2pixHD_model import Pix2PixHDModel
- model = Pix2PixHDModel()
+ if opt.model == 'pix2pixHD':
+ from .pix2pixHD_model import Pix2PixHDModel
+ model = Pix2PixHDModel()
+ else:
+ from .ui_model import UIModel
+ model = UIModel()
model.initialize(opt)
print("model [%s] was created" % (model.name()))
diff --git a/models/ui_model.py b/models/ui_model.py
new file mode 100755
index 0000000..056a335
--- /dev/null
+++ b/models/ui_model.py
@@ -0,0 +1,349 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import torch
+from torch.autograd import Variable
+from collections import OrderedDict
+import numpy as np
+import os
+from PIL import Image
+import util.util as util
+from .base_model import BaseModel
+from . import networks
+
+class UIModel(BaseModel):
+ def name(self):
+ return 'UIModel'
+
+ def initialize(self, opt):
+ assert(not opt.isTrain)
+ BaseModel.initialize(self, opt)
+ self.use_features = opt.instance_feat or opt.label_feat
+
+ netG_input_nc = opt.label_nc
+ if not opt.no_instance:
+ netG_input_nc += 1
+ if self.use_features:
+ netG_input_nc += opt.feat_num
+
+ self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
+ opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
+ opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
+ self.load_network(self.netG, 'G', opt.which_epoch)
+
+ print('---------- Networks initialized -------------')
+
+ def toTensor(self, img, normalize=False):
+ tensor = torch.from_numpy(np.array(img, np.int32, copy=False))
+ tensor = tensor.view(1, img.size[1], img.size[0], len(img.mode))
+ tensor = tensor.transpose(1, 2).transpose(1, 3).contiguous()
+ if normalize:
+ return (tensor.float()/255.0 - 0.5) / 0.5
+ return tensor.float()
+
+ def load_image(self, label_path, inst_path, feat_path):
+ opt = self.opt
+ # read label map
+ label_img = Image.open(label_path)
+ if label_path.find('face') != -1:
+ label_img = label_img.convert('L')
+ ow, oh = label_img.size
+ w = opt.loadSize
+ h = int(w * oh / ow)
+ label_img = label_img.resize((w, h), Image.NEAREST)
+ label_map = self.toTensor(label_img)
+
+ # onehot vector input for label map
+ self.label_map = label_map.cuda()
+ oneHot_size = (1, opt.label_nc, h, w)
+ input_label = self.Tensor(torch.Size(oneHot_size)).zero_()
+ self.input_label = input_label.scatter_(1, label_map.long().cuda(), 1.0)
+
+ # read instance map
+ if not opt.no_instance:
+ inst_img = Image.open(inst_path)
+ inst_img = inst_img.resize((w, h), Image.NEAREST)
+ self.inst_map = self.toTensor(inst_img).cuda()
+ self.edge_map = self.get_edges(self.inst_map)
+ self.net_input = Variable(torch.cat((self.input_label, self.edge_map), dim=1), volatile=True)
+ else:
+ self.net_input = Variable(self.input_label, volatile=True)
+
+ self.features_clustered = np.load(feat_path).item()
+ self.object_map = self.inst_map if opt.instance_feat else self.label_map
+
+ object_np = self.object_map.cpu().numpy().astype(int)
+ self.feat_map = self.Tensor(1, opt.feat_num, h, w).zero_()
+ self.cluster_indices = np.zeros(self.opt.label_nc, np.uint8)
+ for i in np.unique(object_np):
+ label = i if i < 1000 else i//1000
+ if label in self.features_clustered:
+ feat = self.features_clustered[label]
+ np.random.seed(i+1)
+ cluster_idx = np.random.randint(0, feat.shape[0])
+ self.cluster_indices[label] = cluster_idx
+ idx = (self.object_map == i).nonzero()
+ self.set_features(idx, feat, cluster_idx)
+
+ self.net_input_original = self.net_input.clone()
+ self.label_map_original = self.label_map.clone()
+ self.feat_map_original = self.feat_map.clone()
+ if not opt.no_instance:
+ self.inst_map_original = self.inst_map.clone()
+
+ def reset(self):
+ self.net_input = self.net_input_prev = self.net_input_original.clone()
+ self.label_map = self.label_map_prev = self.label_map_original.clone()
+ self.feat_map = self.feat_map_prev = self.feat_map_original.clone()
+ if not self.opt.no_instance:
+ self.inst_map = self.inst_map_prev = self.inst_map_original.clone()
+ self.object_map = self.inst_map if self.opt.instance_feat else self.label_map
+
+ def undo(self):
+ self.net_input = self.net_input_prev
+ self.label_map = self.label_map_prev
+ self.feat_map = self.feat_map_prev
+ if not self.opt.no_instance:
+ self.inst_map = self.inst_map_prev
+ self.object_map = self.inst_map if self.opt.instance_feat else self.label_map
+
+ # get boundary map from instance map
+ def get_edges(self, t):
+ edge = torch.cuda.ByteTensor(t.size()).zero_()
+ edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
+ edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
+ edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
+ edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
+ return edge.float()
+
+ # change the label at the source position to the label at the target position
+ def change_labels(self, click_src, click_tgt):
+ y_src, x_src = click_src[0], click_src[1]
+ y_tgt, x_tgt = click_tgt[0], click_tgt[1]
+ label_src = int(self.label_map[0, 0, y_src, x_src])
+ inst_src = self.inst_map[0, 0, y_src, x_src]
+ label_tgt = int(self.label_map[0, 0, y_tgt, x_tgt])
+ inst_tgt = self.inst_map[0, 0, y_tgt, x_tgt]
+
+ idx_src = (self.inst_map == inst_src).nonzero()
+ # need to change 3 things: label map, instance map, and feature map
+ if idx_src.shape:
+ # backup current maps
+ self.backup_current_state()
+
+ # change both the label map and the network input
+ self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
+ self.net_input[idx_src[:,0], idx_src[:,1] + label_src, idx_src[:,2], idx_src[:,3]] = 0
+ self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1
+
+ # update the instance map (and the network input)
+ if inst_tgt > 1000:
+ # if different instances have different ids, give the new object a new id
+ tgt_indices = (self.inst_map > label_tgt * 1000) & (self.inst_map < (label_tgt+1) * 1000)
+ inst_tgt = self.inst_map[tgt_indices].max() + 1
+ self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = inst_tgt
+ self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)
+
+ # also copy the source features to the target position
+ idx_tgt = (self.inst_map == inst_tgt).nonzero()
+ if idx_tgt.shape:
+ self.copy_features(idx_src, idx_tgt[0,:])
+
+ self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
+
+ # add strokes of target label in the image
+ def add_strokes(self, click_src, label_tgt, bw, save):
+ # get the region of the new strokes (bw is the brush width)
+ size = self.net_input.size()
+ h, w = size[2], size[3]
+ idx_src = torch.LongTensor(bw**2, 4).fill_(0)
+ for i in range(bw):
+ idx_src[i*bw:(i+1)*bw, 2] = min(h-1, max(0, click_src[0]-bw//2 + i))
+ for j in range(bw):
+ idx_src[i*bw+j, 3] = min(w-1, max(0, click_src[1]-bw//2 + j))
+ idx_src = idx_src.cuda()
+
+ # again, need to update 3 things
+ if idx_src.shape:
+ # backup current maps
+ if save:
+ self.backup_current_state()
+
+ # update the label map (and the network input) in the stroke region
+ self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
+ for k in range(self.opt.label_nc):
+ self.net_input[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = 0
+ self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1
+
+ # update the instance map (and the network input)
+ self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
+ self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)
+
+ # also update the features if available
+ if self.opt.instance_feat:
+ feat = self.features_clustered[label_tgt]
+ #np.random.seed(label_tgt+1)
+ #cluster_idx = np.random.randint(0, feat.shape[0])
+ cluster_idx = self.cluster_indices[label_tgt]
+ self.set_features(idx_src, feat, cluster_idx)
+
+ self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
+
+ # add an object to the clicked position with selected style
+ def add_objects(self, click_src, label_tgt, mask, style_id=0):
+ y, x = click_src[0], click_src[1]
+ mask = np.transpose(mask, (2, 0, 1))[np.newaxis,...]
+ idx_src = torch.from_numpy(mask).cuda().nonzero()
+ idx_src[:,2] += y
+ idx_src[:,3] += x
+
+ # backup current maps
+ self.backup_current_state()
+
+ # update label map
+ self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
+ for k in range(self.opt.label_nc):
+ self.net_input[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = 0
+ self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1
+
+ # update instance map
+ self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
+ self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)
+
+ # update feature map
+ self.set_features(idx_src, self.feat, style_id)
+
+ self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
+
+ def single_forward(self, net_input, feat_map):
+ net_input = torch.cat((net_input, feat_map), dim=1)
+ fake_image = self.netG.forward(net_input)
+
+ if fake_image.size()[0] == 1:
+ return fake_image.data[0]
+ return fake_image.data
+
+
+ # generate all outputs for different styles
+ def style_forward(self, click_pt, style_id=-1):
+ if click_pt is None:
+ self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
+ self.crop = None
+ self.mask = None
+ else:
+ instToChange = int(self.object_map[0, 0, click_pt[0], click_pt[1]])
+ self.instToChange = instToChange
+ label = instToChange if instToChange < 1000 else instToChange//1000
+ self.feat = self.features_clustered[label]
+ self.fake_image = []
+ self.mask = self.object_map == instToChange
+ idx = self.mask.nonzero()
+ self.get_crop_region(idx)
+ if idx.size():
+ if style_id == -1:
+ (min_y, min_x, max_y, max_x) = self.crop
+ ### original
+ for cluster_idx in range(self.opt.multiple_output):
+ self.set_features(idx, self.feat, cluster_idx)
+ fake_image = self.single_forward(self.net_input, self.feat_map)
+ fake_image = util.tensor2im(fake_image[:,min_y:max_y,min_x:max_x])
+ self.fake_image.append(fake_image)
+ """### To speed up previewing different style results, either crop or downsample the label maps
+ if instToChange > 1000:
+ (min_y, min_x, max_y, max_x) = self.crop
+ ### crop
+ _, _, h, w = self.net_input.size()
+ offset = 512
+ y_start, x_start = max(0, min_y-offset), max(0, min_x-offset)
+ y_end, x_end = min(h, (max_y + offset)), min(w, (max_x + offset))
+ y_region = slice(y_start, y_start+(y_end-y_start)//16*16)
+ x_region = slice(x_start, x_start+(x_end-x_start)//16*16)
+ net_input = self.net_input[:,:,y_region,x_region]
+ for cluster_idx in range(self.opt.multiple_output):
+ self.set_features(idx, self.feat, cluster_idx)
+ fake_image = self.single_forward(net_input, self.feat_map[:,:,y_region,x_region])
+ fake_image = util.tensor2im(fake_image[:,min_y-y_start:max_y-y_start,min_x-x_start:max_x-x_start])
+ self.fake_image.append(fake_image)
+ else:
+ ### downsample
+ (min_y, min_x, max_y, max_x) = [crop//2 for crop in self.crop]
+ net_input = self.net_input[:,:,::2,::2]
+ size = net_input.size()
+ net_input_batch = net_input.expand(self.opt.multiple_output, size[1], size[2], size[3])
+ for cluster_idx in range(self.opt.multiple_output):
+ self.set_features(idx, self.feat, cluster_idx)
+ feat_map = self.feat_map[:,:,::2,::2]
+ if cluster_idx == 0:
+ feat_map_batch = feat_map
+ else:
+ feat_map_batch = torch.cat((feat_map_batch, feat_map), dim=0)
+ fake_image_batch = self.single_forward(net_input_batch, feat_map_batch)
+ for i in range(self.opt.multiple_output):
+ self.fake_image.append(util.tensor2im(fake_image_batch[i,:,min_y:max_y,min_x:max_x]))"""
+
+ else:
+ self.set_features(idx, self.feat, style_id)
+ self.cluster_indices[label] = style_id
+ self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
+
+ def backup_current_state(self):
+ self.net_input_prev = self.net_input.clone()
+ self.label_map_prev = self.label_map.clone()
+ self.inst_map_prev = self.inst_map.clone()
+ self.feat_map_prev = self.feat_map.clone()
+
+ # crop the ROI and get the mask of the object
+ def get_crop_region(self, idx):
+ size = self.net_input.size()
+ h, w = size[2], size[3]
+ min_y, min_x = idx[:,2].min(), idx[:,3].min()
+ max_y, max_x = idx[:,2].max(), idx[:,3].max()
+ crop_min = 128
+ if max_y - min_y < crop_min:
+ min_y = max(0, (max_y + min_y) // 2 - crop_min // 2)
+ max_y = min(h-1, min_y + crop_min)
+ if max_x - min_x < crop_min:
+ min_x = max(0, (max_x + min_x) // 2 - crop_min // 2)
+ max_x = min(w-1, min_x + crop_min)
+ self.crop = (min_y, min_x, max_y, max_x)
+ self.mask = self.mask[:,:, min_y:max_y, min_x:max_x]
+
+ # update the feature map once a new object is added or the label is changed
+ def update_features(self, cluster_idx, mask=None, click_pt=None):
+ self.feat_map_prev = self.feat_map.clone()
+ # adding a new object
+ if mask is not None:
+ y, x = click_pt[0], click_pt[1]
+ mask = np.transpose(mask, (2,0,1))[np.newaxis,...]
+ idx = torch.from_numpy(mask).cuda().nonzero()
+ idx[:,2] += y
+ idx[:,3] += x
+ # changing the label of an existing object
+ else:
+ idx = (self.object_map == self.instToChange).nonzero()
+
+ # update feature map
+ self.set_features(idx, self.feat, cluster_idx)
+
+ # set the class features to the target feature
+ def set_features(self, idx, feat, cluster_idx):
+ for k in range(self.opt.feat_num):
+ self.feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
+
+ # copy the features at the target position to the source position
+ def copy_features(self, idx_src, idx_tgt):
+ for k in range(self.opt.feat_num):
+ val = self.feat_map[idx_tgt[0], idx_tgt[1] + k, idx_tgt[2], idx_tgt[3]]
+ self.feat_map[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = val
+
+ def get_current_visuals(self, getLabel=False):
+ mask = self.mask
+ if self.mask is not None:
+ mask = np.transpose(self.mask[0].cpu().float().numpy(), (1,2,0)).astype(np.uint8)
+
+ dict_list = [('fake_image', self.fake_image), ('mask', mask)]
+
+ if getLabel: # only output label map if needed to save bandwidth
+ label = util.tensor2label(self.net_input.data[0], self.opt.label_nc)
+ dict_list += [('label', label)]
+
+        return OrderedDict(dict_list)
\ No newline at end of file
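A rough sketch of how the editing API above chains together (file names are placeholders, and `features.npy` must hold the clustered-features dict that `load_image` expects; clicks are (y, x) pixel coordinates):

    model = create_model(opt)                    # UIModel, as above
    model.load_image('seg_label.png', 'seg_inst.png', 'features.npy')
    model.change_labels((120, 300), (40, 500))   # give the object under the
                                                 # source click the label of
                                                 # the target click
    model.style_forward((120, 300), style_id=2)  # re-style the clicked object
    visuals = model.get_current_visuals(getLabel=True)
    # visuals['fake_image'], visuals['mask'], and optionally visuals['label']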
diff --git a/options/base_options.py b/options/base_options.py
index 863c061..de831fe 100755
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -15,6 +15,7 @@ class BaseOptions():
self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models')
self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0, or 0,1,2, or 0,2. use -1 for CPU')
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
+ self.parser.add_argument('--model', type=str, default='pix2pixHD', help='which model to use')
self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')