 models/base_model.py      |   8
 models/models.py          |  10
 models/pix2pixHD_model.py |  35
 options/base_options.py   |   3
 options/test_options.py   |   4
 run_engine.py             | 173
 scripts/test_1024p.sh     |   7
 test.py                   |  36
 8 files changed, 255 insertions, 21 deletions
diff --git a/models/base_model.py b/models/base_model.py
index 88e0587..2cda12f 100755
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -68,7 +68,8 @@ class BaseModel(torch.nn.Module):
                 try:
                     pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                     network.load_state_dict(pretrained_dict)
-                    print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
+                    if self.opt.verbose:
+                        print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
                 except:
                     print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
                     if sys.version_info >= (3,0):
@@ -82,8 +83,9 @@ class BaseModel(torch.nn.Module):
                     for k, v in model_dict.items():
                         if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
-                            not_initialized.add(k.split('.')[0])
-                    print(sorted(not_initialized))
+                            not_initialized.add(k.split('.')[0])
+                    if self.opt.verbose:
+                        print(sorted(not_initialized))
                     network.load_state_dict(model_dict)

     def update_learning_rate():
diff --git a/models/models.py b/models/models.py
index 0ba442f..8e72e46 100755
--- a/models/models.py
+++ b/models/models.py
@@ -4,13 +4,17 @@
 import torch

 def create_model(opt):
     if opt.model == 'pix2pixHD':
-        from .pix2pixHD_model import Pix2PixHDModel
-        model = Pix2PixHDModel()
+        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
+        if opt.isTrain:
+            model = Pix2PixHDModel()
+        else:
+            model = InferenceModel()
     else:
         from .ui_model import UIModel
         model = UIModel()
     model.initialize(opt)
-    print("model [%s] was created" % (model.name()))
+    if opt.verbose:
+        print("model [%s] was created" % (model.name()))

     if opt.isTrain and len(opt.gpu_ids):
         model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py
index b77868a..79ebabd 100755
--- a/models/pix2pixHD_model.py
+++ b/models/pix2pixHD_model.py
@@ -50,8 +50,8 @@ class Pix2PixHDModel(BaseModel):
         if self.gen_features:
             self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
                                           opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
-
-        print('---------- Networks initialized -------------')
+        if self.opt.verbose:
+            print('---------- Networks initialized -------------')

         # load networks
         if not self.isTrain or opt.continue_train or opt.load_pretrain:
@@ -84,7 +84,8 @@ class Pix2PixHDModel(BaseModel):
             # initialize optimizers
             # optimizer G
             if opt.niter_fix_global > 0:
-                print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
+                if self.opt.verbose:
+                    print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                 params_dict = dict(self.netG.named_parameters())
                 params = []
                 for key, value in params_dict.items():
@@ -111,13 +112,15 @@ class Pix2PixHDModel(BaseModel):
             oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
             input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
             input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
+            if self.opt.data_type==16:
+                input_label = input_label.half()

         # get edges from instance map
         if not self.opt.no_instance:
             inst_map = inst_map.data.cuda()
             edge_map = self.get_edges(inst_map)
             input_label = torch.cat((input_label, edge_map), dim=1)
-        input_label = Variable(input_label, volatile=infer)
+        input_label = Variable(input_label, requires_grad = not infer)

         # real images for training
         if real_image is not None:
@@ -212,7 +215,9 @@ class Pix2PixHDModel(BaseModel):
             idx = (inst == i).nonzero()
             for k in range(self.opt.feat_num):
-                feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
+                feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
+        if self.opt.data_type==16:
+            feat_map = feat_map.half()
         return feat_map

     def encode_features(self, image, inst):
@@ -243,7 +248,10 @@ class Pix2PixHDModel(BaseModel):
         edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
         edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
         edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
-        return edge.float()
+        if self.opt.data_type==16:
+            return edge.half()
+        else:
+            return edge.float()

     def save(self, which_epoch):
         self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
@@ -256,8 +264,9 @@ class Pix2PixHDModel(BaseModel):
         params = list(self.netG.parameters())
         if self.gen_features:
             params += list(self.netE.parameters())
-        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
-        print('------------ Now also finetuning global generator -----------')
+        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
+        if self.opt.verbose:
+            print('------------ Now also finetuning global generator -----------')

     def update_learning_rate(self):
         lrd = self.opt.lr / self.opt.niter_decay
@@ -266,5 +275,13 @@ class Pix2PixHDModel(BaseModel):
             param_group['lr'] = lr
         for param_group in self.optimizer_G.param_groups:
             param_group['lr'] = lr
-        print('update learning rate: %f -> %f' % (self.old_lr, lr))
+        if self.opt.verbose:
+            print('update learning rate: %f -> %f' % (self.old_lr, lr))
         self.old_lr = lr
+
+class InferenceModel(Pix2PixHDModel):
+    def forward(self, inp):
+        label, inst = inp
+        return self.inference(label, inst)
+
diff --git a/options/base_options.py b/options/base_options.py
index de831fe..561a890 100755
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -56,7 +56,8 @@ class BaseOptions():
         self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder')
         self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
         self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features')
-
+        self.parser.add_argument('--verbose', action='store_true', default = False, help='toggles verbose')
+
         self.initialized = True

     def parse(self, save=True):
diff --git a/options/test_options.py b/options/test_options.py
index aaeff53..504edf3 100755
--- a/options/test_options.py
+++ b/options/test_options.py
@@ -12,4 +12,8 @@ class TestOptions(BaseOptions):
         self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
         self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
         self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
+        self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
+        self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
+        self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
+        self.parser.add_argument("-d", "--data_type", default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit")
         self.isTrain = False
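Taken together, the new TestOptions flags define the TensorRT workflow that test.py drives further below. A hedged sketch of the export step (checkpoint name and output path are illustrative, not from this commit):

    # Export the 1024p generator to ONNX; test.py exits after writing the file.
    # fp16 is selected with the new -d/--data_type flag.
    python test.py --name label2city_1024p --netG local --ngf 32 \
        --resize_or_crop none --data_type 16 --export_onnx checkpoints/label2city.onnx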
8, 16, 32 bit") self.isTrain = False diff --git a/run_engine.py b/run_engine.py new file mode 100644 index 0000000..700494d --- /dev/null +++ b/run_engine.py @@ -0,0 +1,173 @@ +import os +import sys +from random import randint +import numpy as np +import tensorrt + +try: + from PIL import Image + import pycuda.driver as cuda + import pycuda.gpuarray as gpuarray + import pycuda.autoinit + import argparse +except ImportError as err: + sys.stderr.write("""ERROR: failed to import module ({}) +Please make sure you have pycuda and the example dependencies installed. +https://wiki.tiker.net/PyCuda/Installation/Linux +pip(3) install tensorrt[examples] +""".format(err)) + exit(1) + +try: + import tensorrt as trt + from tensorrt.parsers import caffeparser + from tensorrt.parsers import onnxparser +except ImportError as err: + sys.stderr.write("""ERROR: failed to import module ({}) +Please make sure you have the TensorRT Library installed +and accessible in your LD_LIBRARY_PATH +""".format(err)) + exit(1) + + +G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.INFO) + +class Profiler(trt.infer.Profiler): + """ + Example Implimentation of a Profiler + Is identical to the Profiler class in trt.infer so it is possible + to just use that instead of implementing this if further + functionality is not needed + """ + def __init__(self, timing_iter): + trt.infer.Profiler.__init__(self) + self.timing_iterations = timing_iter + self.profile = [] + + def report_layer_time(self, layerName, ms): + record = next((r for r in self.profile if r[0] == layerName), (None, None)) + if record == (None, None): + self.profile.append((layerName, ms)) + else: + self.profile[self.profile.index(record)] = (record[0], record[1] + ms) + + def print_layer_times(self): + totalTime = 0 + for i in range(len(self.profile)): + print("{:40.40} {:4.3f}ms".format(self.profile[i][0], self.profile[i][1] / self.timing_iterations)) + totalTime += self.profile[i][1] + print("Time over all layers: {:4.2f} ms per iteration".format(totalTime / self.timing_iterations)) + + +def get_input_output_names(trt_engine): + nbindings = trt_engine.get_nb_bindings(); + maps = [] + + for b in range(0, nbindings): + dims = trt_engine.get_binding_dimensions(b).to_DimsCHW() + name = trt_engine.get_binding_name(b) + type = trt_engine.get_binding_data_type(b) + + if (trt_engine.binding_is_input(b)): + maps.append(name) + print("Found input: ", name) + else: + maps.append(name) + print("Found output: ", name) + + print("shape=" + str(dims.C()) + " , " + str(dims.H()) + " , " + str(dims.W())) + print("dtype=" + str(type)) + return maps + +def create_memory(engine, name, buf, mem, batchsize, inp, inp_idx): + binding_idx = engine.get_binding_index(name) + if binding_idx == -1: + raise AttributeError("Not a valid binding") + print("Binding: name={}, bindingIndex={}".format(name, str(binding_idx))) + dims = engine.get_binding_dimensions(binding_idx).to_DimsCHW() + eltCount = dims.C() * dims.H() * dims.W() * batchsize + + if engine.binding_is_input(binding_idx): + h_mem = inp[inp_idx] + inp_idx = inp_idx + 1 + else: + h_mem = np.random.uniform(0.0, 255.0, eltCount).astype(np.dtype('f4')) + + d_mem = cuda.mem_alloc(eltCount * 4) + cuda.memcpy_htod(d_mem, h_mem) + buf.insert(binding_idx, int(d_mem)) + mem.append(d_mem) + return inp_idx + + +#Run inference on device +def time_inference(engine, batch_size, inp): + bindings = [] + mem = [] + inp_idx = 0 + for io in get_input_output_names(engine): + inp_idx = create_memory(engine, io, bindings, mem, + batch_size, inp, 
inp_idx) + + context = engine.create_execution_context() + g_prof = Profiler(500) + context.set_profiler(g_prof) + for i in range(iter): + context.execute(batch_size, bindings) + g_prof.print_layer_times() + + context.destroy() + return + + +def convert_to_datatype(v): + if v==8: + return trt.infer.DataType.INT8 + elif v==16: + return trt.infer.DataType.HALF + elif v==32: + return trt.infer.DataType.FLOAT + else: + print("ERROR: Invalid model data type bit depth: " + str(v)) + return trt.infer.DataType.INT8 + +def run_trt_engine(engine_file, bs, it): + engine = trt.utils.load_engine(G_LOGGER, engine_file) + time_inference(engine, bs, it) + +def run_onnx(onnx_file, data_type, bs, inp): + # Create onnx_config + apex = onnxparser.create_onnxconfig() + apex.set_model_file_name(onnx_file) + apex.set_model_dtype(convert_to_datatype(data_type)) + + # create parser + trt_parser = onnxparser.create_onnxparser(apex) + assert(trt_parser) + data_type = apex.get_model_dtype() + onnx_filename = apex.get_model_file_name() + trt_parser.parse(onnx_filename, data_type) + trt_parser.report_parsing_info() + trt_parser.convert_to_trtnetwork() + trt_network = trt_parser.get_trtnetwork() + assert(trt_network) + + # create infer builder + trt_builder = trt.infer.create_infer_builder(G_LOGGER) + trt_builder.set_max_batch_size(max_batch_size) + trt_builder.set_max_workspace_size(max_workspace_size) + + if (apex.get_model_dtype() == trt.infer.DataType_kHALF): + print("------------------- Running FP16 -----------------------------") + trt_builder.set_half2_mode(True) + elif (apex.get_model_dtype() == trt.infer.DataType_kINT8): + print("------------------- Running INT8 -----------------------------") + trt_builder.set_int8_mode(True) + else: + print("------------------- Running FP32 -----------------------------") + + print("----- Builder is Done -----") + print("----- Creating Engine -----") + trt_engine = trt_builder.build_cuda_engine(trt_network) + print("----- Engine is built -----") + time_inference(engine, bs, inp) diff --git a/scripts/test_1024p.sh b/scripts/test_1024p.sh index 99c1e24..319803c 100755 --- a/scripts/test_1024p.sh +++ b/scripts/test_1024p.sh @@ -1,3 +1,4 @@ -################################ Testing ################################
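run_engine.py has no main entry point; test.py imports run_trt_engine and run_onnx and passes the label and instance tensors as the input list. A sketch of the two paths from the command line (engine and model file names are assumptions):

    # Run a previously serialized TensorRT engine
    python test.py --name label2city_1024p --netG local --ngf 32 \
        --resize_or_crop none --engine label2city.trt

    # Or have TensorRT parse and build from the exported ONNX graph
    python test.py --name label2city_1024p --netG local --ngf 32 \
        --resize_or_crop none --data_type 16 --onnx checkpoints/label2city.onnx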
diff --git a/scripts/test_1024p.sh b/scripts/test_1024p.sh
index 99c1e24..319803c 100755
--- a/scripts/test_1024p.sh
+++ b/scripts/test_1024p.sh
@@ -1,3 +1,4 @@
-################################ Testing ################################
-# labels only
-python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none
\ No newline at end of file
+#!/bin/bash
+################################ Testing ################################
+# labels only
+python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none $@
diff --git a/test.py b/test.py
--- a/test.py
+++ b/test.py
@@ -8,6 +8,8 @@ from models.models import create_model
 import util.util as util
 from util.visualizer import Visualizer
 from util import html
+import torch
+from run_engine import run_trt_engine, run_onnx

 opt = TestOptions().parse(save=False)
 opt.nThreads = 1   # test code only supports nThreads = 1
@@ -17,16 +19,46 @@ opt.no_flip = True  # no flip

 data_loader = CreateDataLoader(opt)
 dataset = data_loader.load_data()
-model = create_model(opt)
 visualizer = Visualizer(opt)
 # create website
 web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
 webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))

 # test
+if not opt.engine and not opt.onnx:
+    model = create_model(opt)
+    if opt.data_type == 16:
+        model.half()
+    elif opt.data_type == 8:
+        model.type(torch.uint8)
+
+    if opt.verbose:
+        print(model)
+
 for i, data in enumerate(dataset):
     if i >= opt.how_many:
         break
-    generated = model.inference(data['label'], data['inst'])
+    if opt.data_type == 16:
+        data['label'] = data['label'].half()
+        data['inst'] = data['inst'].half()
+    elif opt.data_type == 8:
+        data['label'] = data['label'].uint8()
+        data['inst'] = data['inst'].uint8()
+    if opt.export_onnx:
+        print ("Exporting to ONNX: ", opt.export_onnx)
+        assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
+        torch.onnx.export(model, [data['label'], data['inst']],
+                          opt.export_onnx, verbose=True)
+        exit(0)
+    minibatch = 1
+    if opt.engine:
+        generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
+    elif opt.onnx:
+        generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
+    else:
+        generated = model.inference(data['label'], data['inst'])
+
     visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                            ('synthesized_image', util.tensor2im(generated.data[0]))])
     img_path = data['path']
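Without --engine or --onnx, test.py still runs the plain PyTorch path; the only behavioral change is the optional reduced-precision cast of the model and its inputs. An illustrative fp16 run through the updated script, which now forwards extra flags via $@:

    # $@ in scripts/test_1024p.sh passes these flags through to test.py
    ./scripts/test_1024p.sh --data_type 16 --verbose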
