diff options
| -rwxr-xr-x | models/base_model.py | 8 | ||||
| -rwxr-xr-x | models/models.py | 3 | ||||
| -rwxr-xr-x | models/pix2pixHD_model.py | 28 | ||||
| -rwxr-xr-x | options/base_options.py | 3 | ||||
| -rwxr-xr-x | options/test_options.py | 4 | ||||
| -rwxr-xr-x | scripts/test_1024p.sh | 2 | ||||
| -rwxr-xr-x | test.py | 14 |
7 files changed, 47 insertions, 15 deletions
diff --git a/models/base_model.py b/models/base_model.py index 88e0587..2cda12f 100755 --- a/models/base_model.py +++ b/models/base_model.py @@ -68,7 +68,8 @@ class BaseModel(torch.nn.Module): try: pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} network.load_state_dict(pretrained_dict) - print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) + if self.opt.verbose: + print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) except: print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label) if sys.version_info >= (3,0): @@ -82,8 +83,9 @@ class BaseModel(torch.nn.Module): for k, v in model_dict.items(): if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): - not_initialized.add(k.split('.')[0]) - print(sorted(not_initialized)) + not_initialized.add(k.split('.')[0]) + if self.opt.verbose: + print(sorted(not_initialized)) network.load_state_dict(model_dict) def update_learning_rate(): diff --git a/models/models.py b/models/models.py index 0ba442f..805696f 100755 --- a/models/models.py +++ b/models/models.py @@ -10,7 +10,8 @@ def create_model(opt): from .ui_model import UIModel model = UIModel() model.initialize(opt) - print("model [%s] was created" % (model.name())) + if opt.verbose: + print("model [%s] was created" % (model.name())) if opt.isTrain and len(opt.gpu_ids): model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py index 834fc18..de594ab 100755 --- a/models/pix2pixHD_model.py +++ b/models/pix2pixHD_model.py @@ -45,8 +45,8 @@ class Pix2PixHDModel(BaseModel): if self.gen_features: self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids) - - print('---------- Networks initialized -------------') + if self.opt.verbose: + print('---------- Networks initialized -------------') # load networks if not self.isTrain or opt.continue_train or opt.load_pretrain: @@ -76,7 +76,8 @@ class Pix2PixHDModel(BaseModel): # initialize optimizers # optimizer G if opt.niter_fix_global > 0: - print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global) + if self.opt.verbose: + print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global) params_dict = dict(self.netG.named_parameters()) params = [] for key, value in params_dict.items(): @@ -103,13 +104,15 @@ class Pix2PixHDModel(BaseModel): oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) + if self.opt.data_type==16: + input_label = input_label.half() # get edges from instance map if not self.opt.no_instance: inst_map = inst_map.data.cuda() edge_map = self.get_edges(inst_map) input_label = torch.cat((input_label, edge_map), dim=1) - input_label = Variable(input_label, volatile=infer) + input_label = Variable(input_label, requires_grad = not infer) # real images for training if real_image is not None: @@ -204,7 +207,9 @@ class Pix2PixHDModel(BaseModel): idx = (inst == i).nonzero() for k in range(self.opt.feat_num): - feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k] + feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k] + if self.opt.data_type==16: + feat_map = feat_map.half() return feat_map def encode_features(self, image, inst): @@ -235,7 +240,10 @@ class Pix2PixHDModel(BaseModel): edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1]) edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) - return edge.float() + if self.opt.data_type==16: + return edge.half() + else: + return edge.float() def save(self, which_epoch): self.save_network(self.netG, 'G', which_epoch, self.gpu_ids) @@ -248,8 +256,9 @@ class Pix2PixHDModel(BaseModel): params = list(self.netG.parameters()) if self.gen_features: params += list(self.netE.parameters()) - self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) - print('------------ Now also finetuning global generator -----------') + self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) + if self.opt.verbose: + print('------------ Now also finetuning global generator -----------') def update_learning_rate(self): lrd = self.opt.lr / self.opt.niter_decay @@ -258,5 +267,6 @@ class Pix2PixHDModel(BaseModel): param_group['lr'] = lr for param_group in self.optimizer_G.param_groups: param_group['lr'] = lr - print('update learning rate: %f -> %f' % (self.old_lr, lr)) + if self.opt.verbose: + print('update learning rate: %f -> %f' % (self.old_lr, lr)) self.old_lr = lr diff --git a/options/base_options.py b/options/base_options.py index de831fe..561a890 100755 --- a/options/base_options.py +++ b/options/base_options.py @@ -56,7 +56,8 @@ class BaseOptions(): self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder') self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer') self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features') - + self.parser.add_argument('--verbose', action='store_true', default = False, help='toggles verbose') + self.initialized = True def parse(self, save=True): diff --git a/options/test_options.py b/options/test_options.py index aaeff53..504edf3 100755 --- a/options/test_options.py +++ b/options/test_options.py @@ -12,4 +12,8 @@ class TestOptions(BaseOptions): self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run') self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features') + self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file") + self.parser.add_argument("--engine", type=str, help="run serialized TRT engine") + self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT") + self.parser.add_argument("-d", "--data_type", default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit") self.isTrain = False diff --git a/scripts/test_1024p.sh b/scripts/test_1024p.sh index 99c1e24..7526f28 100755 --- a/scripts/test_1024p.sh +++ b/scripts/test_1024p.sh @@ -1,3 +1,3 @@ ################################ Testing ################################
# labels only
-python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none
\ No newline at end of file +python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none $@
@@ -26,6 +26,20 @@ webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.na for i, data in enumerate(dataset): if i >= opt.how_many: break + if opt.data_type == 16: + model.half() + data['label'] = data['label'].half() + data['inst'] = data['inst'].half() + elif opt.data_type == 8: + model.type(torch.uint8) + + if opt.export_onnx: + assert opt.export_onnx.endswith(".onnx"), "Export model file should end with .onnx" + if opt.verbose: + print(model) + generated = torch.onnx.export(model, [data['label'], data['inst']], + opt.export_onnx, verbose=True) + generated = model.inference(data['label'], data['inst']) visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)), ('synthesized_image', util.tensor2im(generated.data[0]))]) |
