summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTing-Chun Wang <tcwang0509@berkeley.edu>2018-05-30 22:39:01 -0700
committerGitHub <noreply@github.com>2018-05-30 22:39:01 -0700
commita2340c3fff9de44c8ef1fea5b90fced756fbbb18 (patch)
tree39f3c05c80a94d721ec6fed0f0da65ecbc3bc603
parent1b89cd010dce2e6edaa07d23c8edd8dfe146e0e1 (diff)
parent25e205604e7eafa83867a15cfda526461fe58455 (diff)
Merge pull request #33 from borisfom/fp16
Added data size and ONNX export options, FP16 inference is working
-rwxr-xr-xmodels/base_model.py8
-rwxr-xr-xmodels/models.py10
-rwxr-xr-xmodels/pix2pixHD_model.py35
-rwxr-xr-xoptions/base_options.py3
-rwxr-xr-xoptions/test_options.py4
-rw-r--r--run_engine.py173
-rwxr-xr-xscripts/test_1024p.sh7
-rwxr-xr-xtest.py36
8 files changed, 255 insertions, 21 deletions
diff --git a/models/base_model.py b/models/base_model.py
index 88e0587..2cda12f 100755
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -68,7 +68,8 @@ class BaseModel(torch.nn.Module):
try:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
network.load_state_dict(pretrained_dict)
- print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
+ if self.opt.verbose:
+ print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
except:
print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
if sys.version_info >= (3,0):
@@ -82,8 +83,9 @@ class BaseModel(torch.nn.Module):
for k, v in model_dict.items():
if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
- not_initialized.add(k.split('.')[0])
- print(sorted(not_initialized))
+ not_initialized.add(k.split('.')[0])
+ if self.opt.verbose:
+ print(sorted(not_initialized))
network.load_state_dict(model_dict)
def update_learning_rate():
diff --git a/models/models.py b/models/models.py
index 0ba442f..8e72e46 100755
--- a/models/models.py
+++ b/models/models.py
@@ -4,13 +4,17 @@ import torch
def create_model(opt):
if opt.model == 'pix2pixHD':
- from .pix2pixHD_model import Pix2PixHDModel
- model = Pix2PixHDModel()
+ from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
+ if opt.isTrain:
+ model = Pix2PixHDModel()
+ else:
+ model = InferenceModel()
else:
from .ui_model import UIModel
model = UIModel()
model.initialize(opt)
- print("model [%s] was created" % (model.name()))
+ if opt.verbose:
+ print("model [%s] was created" % (model.name()))
if opt.isTrain and len(opt.gpu_ids):
model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py
index b77868a..79ebabd 100755
--- a/models/pix2pixHD_model.py
+++ b/models/pix2pixHD_model.py
@@ -50,8 +50,8 @@ class Pix2PixHDModel(BaseModel):
if self.gen_features:
self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
-
- print('---------- Networks initialized -------------')
+ if self.opt.verbose:
+ print('---------- Networks initialized -------------')
# load networks
if not self.isTrain or opt.continue_train or opt.load_pretrain:
@@ -84,7 +84,8 @@ class Pix2PixHDModel(BaseModel):
# initialize optimizers
# optimizer G
if opt.niter_fix_global > 0:
- print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
+ if self.opt.verbose:
+ print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
params_dict = dict(self.netG.named_parameters())
params = []
for key, value in params_dict.items():
@@ -111,13 +112,15 @@ class Pix2PixHDModel(BaseModel):
oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
+ if self.opt.data_type==16:
+ input_label = input_label.half()
# get edges from instance map
if not self.opt.no_instance:
inst_map = inst_map.data.cuda()
edge_map = self.get_edges(inst_map)
input_label = torch.cat((input_label, edge_map), dim=1)
- input_label = Variable(input_label, volatile=infer)
+ input_label = Variable(input_label, requires_grad = not infer)
# real images for training
if real_image is not None:
@@ -212,7 +215,9 @@ class Pix2PixHDModel(BaseModel):
idx = (inst == i).nonzero()
for k in range(self.opt.feat_num):
- feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
+ feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
+ if self.opt.data_type==16:
+ feat_map = feat_map.half()
return feat_map
def encode_features(self, image, inst):
@@ -243,7 +248,10 @@ class Pix2PixHDModel(BaseModel):
edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
- return edge.float()
+ if self.opt.data_type==16:
+ return edge.half()
+ else:
+ return edge.float()
def save(self, which_epoch):
self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
@@ -256,8 +264,9 @@ class Pix2PixHDModel(BaseModel):
params = list(self.netG.parameters())
if self.gen_features:
params += list(self.netE.parameters())
- self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
- print('------------ Now also finetuning global generator -----------')
+ self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
+ if self.opt.verbose:
+ print('------------ Now also finetuning global generator -----------')
def update_learning_rate(self):
lrd = self.opt.lr / self.opt.niter_decay
@@ -266,5 +275,13 @@ class Pix2PixHDModel(BaseModel):
param_group['lr'] = lr
for param_group in self.optimizer_G.param_groups:
param_group['lr'] = lr
- print('update learning rate: %f -> %f' % (self.old_lr, lr))
+ if self.opt.verbose:
+ print('update learning rate: %f -> %f' % (self.old_lr, lr))
self.old_lr = lr
+
class InferenceModel(Pix2PixHDModel):
    """Inference-only variant of Pix2PixHDModel.

    Routes ``forward`` to :meth:`Pix2PixHDModel.inference` so callers that
    invoke the model directly with a packed ``(label, inst)`` pair (e.g. the
    ONNX export path in test.py) run inference.
    """
    def forward(self, inp):
        # inp is a (label_map, inst_map) pair packed by the caller.
        label_map, inst_map = inp
        return self.inference(label_map, inst_map)
+
+
diff --git a/options/base_options.py b/options/base_options.py
index de831fe..561a890 100755
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -56,7 +56,8 @@ class BaseOptions():
self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder')
self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features')
-
+ self.parser.add_argument('--verbose', action='store_true', default = False, help='toggles verbose')
+
self.initialized = True
def parse(self, save=True):
diff --git a/options/test_options.py b/options/test_options.py
index aaeff53..504edf3 100755
--- a/options/test_options.py
+++ b/options/test_options.py
@@ -12,4 +12,8 @@ class TestOptions(BaseOptions):
self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
+ self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
+ self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
+ self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
+ self.parser.add_argument("-d", "--data_type", default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit")
self.isTrain = False
diff --git a/run_engine.py b/run_engine.py
new file mode 100644
index 0000000..700494d
--- /dev/null
+++ b/run_engine.py
@@ -0,0 +1,173 @@
+import os
+import sys
+from random import randint
+import numpy as np
+import tensorrt
+
+try:
+ from PIL import Image
+ import pycuda.driver as cuda
+ import pycuda.gpuarray as gpuarray
+ import pycuda.autoinit
+ import argparse
+except ImportError as err:
+ sys.stderr.write("""ERROR: failed to import module ({})
+Please make sure you have pycuda and the example dependencies installed.
+https://wiki.tiker.net/PyCuda/Installation/Linux
+pip(3) install tensorrt[examples]
+""".format(err))
+ exit(1)
+
+try:
+ import tensorrt as trt
+ from tensorrt.parsers import caffeparser
+ from tensorrt.parsers import onnxparser
+except ImportError as err:
+ sys.stderr.write("""ERROR: failed to import module ({})
+Please make sure you have the TensorRT Library installed
+and accessible in your LD_LIBRARY_PATH
+""".format(err))
+ exit(1)
+
+
+G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.INFO)
+
class Profiler(trt.infer.Profiler):
    """Accumulating per-layer timing profiler.

    Example implementation of a profiler; it is functionally identical to the
    stock ``trt.infer.Profiler``, so that one can be used instead unless extra
    behaviour is needed. Layer times are summed across iterations and averaged
    over ``timing_iter`` when printed.
    """
    def __init__(self, timing_iter):
        trt.infer.Profiler.__init__(self)
        # Number of engine executions the totals will be averaged over.
        self.timing_iterations = timing_iter
        # List of (layer_name, accumulated_ms) pairs, in first-seen order.
        self.profile = []

    def report_layer_time(self, layerName, ms):
        """Add *ms* to the running total for *layerName* (insert if new)."""
        for pos, (name, total) in enumerate(self.profile):
            if name == layerName:
                self.profile[pos] = (name, total + ms)
                return
        self.profile.append((layerName, ms))

    def print_layer_times(self):
        """Print each layer's average per-iteration time and the overall total."""
        totalTime = 0
        for name, total in self.profile:
            print("{:40.40} {:4.3f}ms".format(name, total / self.timing_iterations))
            totalTime += total
        print("Time over all layers: {:4.2f} ms per iteration".format(totalTime / self.timing_iterations))
+
+
def get_input_output_names(trt_engine):
    """Return the engine's binding names, in binding-index order.

    Logs, for every binding, whether it is an input or an output, its CHW
    shape and its data type.

    Args:
        trt_engine: a built TensorRT engine exposing the binding API.

    Returns:
        list: one name per binding, inputs and outputs alike.
    """
    nbindings = trt_engine.get_nb_bindings()
    maps = []

    for b in range(nbindings):
        dims = trt_engine.get_binding_dimensions(b).to_DimsCHW()
        name = trt_engine.get_binding_name(b)
        # Renamed from `type` so the builtin is not shadowed.
        dtype = trt_engine.get_binding_data_type(b)

        # Both branches appended identically in the original; appending once
        # and branching only on the log message is equivalent.
        maps.append(name)
        if trt_engine.binding_is_input(b):
            print("Found input: ", name)
        else:
            print("Found output: ", name)

        print("shape=" + str(dims.C()) + " , " + str(dims.H()) + " , " + str(dims.W()))
        print("dtype=" + str(dtype))
    return maps
+
def create_memory(engine, name, buf, mem, batchsize, inp, inp_idx):
    """Allocate and register device memory for one engine binding.

    For an input binding the host data is taken from ``inp[inp_idx]`` (and
    ``inp_idx`` is advanced); for an output binding a random fp32 host buffer
    is generated. The device pointer is inserted into ``buf`` at the binding
    index and kept alive by appending it to ``mem``.

    Raises:
        AttributeError: if *name* is not a binding of *engine*.

    Returns:
        int: the (possibly advanced) ``inp_idx``.
    """
    binding_idx = engine.get_binding_index(name)
    if binding_idx == -1:
        raise AttributeError("Not a valid binding")
    print("Binding: name={}, bindingIndex={}".format(name, str(binding_idx)))

    dims = engine.get_binding_dimensions(binding_idx).to_DimsCHW()
    elt_count = batchsize * dims.C() * dims.H() * dims.W()

    if engine.binding_is_input(binding_idx):
        host_buf = inp[inp_idx]
        inp_idx += 1
    else:
        host_buf = np.random.uniform(0.0, 255.0, elt_count).astype(np.dtype('f4'))

    dev_buf = cuda.mem_alloc(elt_count * 4)  # 4 bytes per fp32 element
    cuda.memcpy_htod(dev_buf, host_buf)
    buf.insert(binding_idx, int(dev_buf))
    mem.append(dev_buf)
    return inp_idx
+
+
+#Run inference on device
def time_inference(engine, batch_size, inp):
    """Run profiled inference on *engine* and print per-layer timings.

    Device buffers are created for every binding (inputs fed from *inp*),
    the engine is executed repeatedly under a :class:`Profiler`, and the
    averaged per-layer times are printed.

    Args:
        engine: a built TensorRT engine.
        batch_size (int): batch size passed to ``context.execute``.
        inp: host input buffers, one per engine input binding.
    """
    bindings = []
    mem = []
    inp_idx = 0
    for io in get_input_output_names(engine):
        inp_idx = create_memory(engine, io, bindings, mem,
                                batch_size, inp, inp_idx)

    context = engine.create_execution_context()
    g_prof = Profiler(500)
    context.set_profiler(g_prof)
    # BUG FIX: the original looped `for i in range(iter)` — `iter` is the
    # Python builtin function, so range() raised TypeError. Execute the same
    # number of iterations the profiler averages over.
    for _ in range(g_prof.timing_iterations):
        context.execute(batch_size, bindings)
    g_prof.print_layer_times()

    context.destroy()
    return
+
+
def convert_to_datatype(v):
    """Map a model bit depth (8, 16 or 32) to the TensorRT DataType enum.

    An unsupported depth is reported on stdout and falls back to INT8,
    matching the original behaviour.
    """
    mapping = {
        8: trt.infer.DataType.INT8,
        16: trt.infer.DataType.HALF,
        32: trt.infer.DataType.FLOAT,
    }
    if v not in mapping:
        print("ERROR: Invalid model data type bit depth: " + str(v))
        return trt.infer.DataType.INT8
    return mapping[v]
+
def run_trt_engine(engine_file, bs, it):
    """Deserialize a saved TensorRT engine and time inference on it.

    Args:
        engine_file (str): path to the serialized engine file.
        bs (int): batch size.
        it: host input buffers, one per engine input binding.
    """
    trt_engine = trt.utils.load_engine(G_LOGGER, engine_file)
    time_inference(trt_engine, bs, it)
+
def run_onnx(onnx_file, data_type, bs, inp):
    """Parse an ONNX model, build a TensorRT engine and time inference.

    Args:
        onnx_file (str): path to the ONNX model file.
        data_type (int): model bit depth (8, 16 or 32).
        bs (int): batch size used both to build and to run the engine.
        inp: host input buffers, one per network input.
    """
    # Create onnx_config
    apex = onnxparser.create_onnxconfig()
    apex.set_model_file_name(onnx_file)
    apex.set_model_dtype(convert_to_datatype(data_type))

    # create parser
    trt_parser = onnxparser.create_onnxparser(apex)
    assert(trt_parser)
    data_type = apex.get_model_dtype()
    onnx_filename = apex.get_model_file_name()
    trt_parser.parse(onnx_filename, data_type)
    trt_parser.report_parsing_info()
    trt_parser.convert_to_trtnetwork()
    trt_network = trt_parser.get_trtnetwork()
    assert(trt_network)

    # create infer builder
    trt_builder = trt.infer.create_infer_builder(G_LOGGER)
    # BUG FIX: `max_batch_size` and `max_workspace_size` were undefined names
    # and raised NameError. Build for the requested batch size with a 1 GiB
    # workspace.
    trt_builder.set_max_batch_size(bs)
    trt_builder.set_max_workspace_size(1 << 30)

    if (apex.get_model_dtype() == trt.infer.DataType_kHALF):
        print("------------------- Running FP16 -----------------------------")
        trt_builder.set_half2_mode(True)
    elif (apex.get_model_dtype() == trt.infer.DataType_kINT8):
        print("------------------- Running INT8 -----------------------------")
        trt_builder.set_int8_mode(True)
    else:
        print("------------------- Running FP32 -----------------------------")

    print("----- Builder is Done -----")
    print("----- Creating Engine -----")
    trt_engine = trt_builder.build_cuda_engine(trt_network)
    print("----- Engine is built -----")
    # BUG FIX: the original passed undefined `engine`; use the engine just built.
    time_inference(trt_engine, bs, inp)
diff --git a/scripts/test_1024p.sh b/scripts/test_1024p.sh
index 99c1e24..319803c 100755
--- a/scripts/test_1024p.sh
+++ b/scripts/test_1024p.sh
@@ -1,3 +1,4 @@
-################################ Testing ################################
-# labels only
-python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none \ No newline at end of file
+#!/bin/bash
+################################ Testing ################################
+# labels only
+python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none $@
diff --git a/test.py b/test.py
index a9c8729..203d887 100755
--- a/test.py
+++ b/test.py
@@ -8,6 +8,8 @@ from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
from util import html
+import torch
+from run_engine import run_trt_engine, run_onnx
opt = TestOptions().parse(save=False)
opt.nThreads = 1 # test code only supports nThreads = 1
@@ -17,16 +19,46 @@ opt.no_flip = True # no flip
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
-model = create_model(opt)
visualizer = Visualizer(opt)
# create website
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
# test
+
+if not opt.engine and not opt.onnx:
+ model = create_model(opt)
+ if opt.data_type == 16:
+ model.half()
+ elif opt.data_type == 8:
+ model.type(torch.uint8)
+
+ if opt.verbose:
+ print(model)
+
+
for i, data in enumerate(dataset):
if i >= opt.how_many:
break
- generated = model.inference(data['label'], data['inst'])
+ if opt.data_type == 16:
+ data['label'] = data['label'].half()
+ data['inst'] = data['inst'].half()
+ elif opt.data_type == 8:
+ data['label'] = data['label'].uint8()
+ data['inst'] = data['inst'].uint8()
+ if opt.export_onnx:
+ print ("Exporting to ONNX: ", opt.export_onnx)
+ assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
+ torch.onnx.export(model, [data['label'], data['inst']],
+ opt.export_onnx, verbose=True)
+ exit(0)
+ minibatch = 1
+ if opt.engine:
+ generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
+ elif opt.onnx:
+ generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
+ else:
+ generated = model.inference(data['label'], data['inst'])
+
visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
('synthesized_image', util.tensor2im(generated.data[0]))])
img_path = data['path']