summaryrefslogtreecommitdiff
path: root/run_engine.py
diff options
context:
space:
mode:
authorBoris Fomitchev <bfomitchev@nvidia.com>2018-05-08 20:18:10 -0700
committerBoris Fomitchev <bfomitchev@nvidia.com>2018-05-08 20:18:10 -0700
commit25e205604e7eafa83867a15cfda526461fe58455 (patch)
treefcce10851fb0d1627b60cc23100659506f1462bb /run_engine.py
parent4ca6b1610f9fa65f8bd7d7c15059bfde18a2f02a (diff)
ONNX export is working
Diffstat (limited to 'run_engine.py')
-rw-r--r--run_engine.py173
1 files changed, 173 insertions, 0 deletions
diff --git a/run_engine.py b/run_engine.py
new file mode 100644
index 0000000..700494d
--- /dev/null
+++ b/run_engine.py
@@ -0,0 +1,173 @@
+import os
+import sys
+from random import randint
+import numpy as np
+import tensorrt
+
+try:
+ from PIL import Image
+ import pycuda.driver as cuda
+ import pycuda.gpuarray as gpuarray
+ import pycuda.autoinit
+ import argparse
+except ImportError as err:
+ sys.stderr.write("""ERROR: failed to import module ({})
+Please make sure you have pycuda and the example dependencies installed.
+https://wiki.tiker.net/PyCuda/Installation/Linux
+pip(3) install tensorrt[examples]
+""".format(err))
+ exit(1)
+
+try:
+ import tensorrt as trt
+ from tensorrt.parsers import caffeparser
+ from tensorrt.parsers import onnxparser
+except ImportError as err:
+ sys.stderr.write("""ERROR: failed to import module ({})
+Please make sure you have the TensorRT Library installed
+and accessible in your LD_LIBRARY_PATH
+""".format(err))
+ exit(1)
+
+
+G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.INFO)
+
+class Profiler(trt.infer.Profiler):
+ """
+ Example Implimentation of a Profiler
+ Is identical to the Profiler class in trt.infer so it is possible
+ to just use that instead of implementing this if further
+ functionality is not needed
+ """
+ def __init__(self, timing_iter):
+ trt.infer.Profiler.__init__(self)
+ self.timing_iterations = timing_iter
+ self.profile = []
+
+ def report_layer_time(self, layerName, ms):
+ record = next((r for r in self.profile if r[0] == layerName), (None, None))
+ if record == (None, None):
+ self.profile.append((layerName, ms))
+ else:
+ self.profile[self.profile.index(record)] = (record[0], record[1] + ms)
+
+ def print_layer_times(self):
+ totalTime = 0
+ for i in range(len(self.profile)):
+ print("{:40.40} {:4.3f}ms".format(self.profile[i][0], self.profile[i][1] / self.timing_iterations))
+ totalTime += self.profile[i][1]
+ print("Time over all layers: {:4.2f} ms per iteration".format(totalTime / self.timing_iterations))
+
+
+def get_input_output_names(trt_engine):
+ nbindings = trt_engine.get_nb_bindings();
+ maps = []
+
+ for b in range(0, nbindings):
+ dims = trt_engine.get_binding_dimensions(b).to_DimsCHW()
+ name = trt_engine.get_binding_name(b)
+ type = trt_engine.get_binding_data_type(b)
+
+ if (trt_engine.binding_is_input(b)):
+ maps.append(name)
+ print("Found input: ", name)
+ else:
+ maps.append(name)
+ print("Found output: ", name)
+
+ print("shape=" + str(dims.C()) + " , " + str(dims.H()) + " , " + str(dims.W()))
+ print("dtype=" + str(type))
+ return maps
+
+def create_memory(engine, name, buf, mem, batchsize, inp, inp_idx):
+ binding_idx = engine.get_binding_index(name)
+ if binding_idx == -1:
+ raise AttributeError("Not a valid binding")
+ print("Binding: name={}, bindingIndex={}".format(name, str(binding_idx)))
+ dims = engine.get_binding_dimensions(binding_idx).to_DimsCHW()
+ eltCount = dims.C() * dims.H() * dims.W() * batchsize
+
+ if engine.binding_is_input(binding_idx):
+ h_mem = inp[inp_idx]
+ inp_idx = inp_idx + 1
+ else:
+ h_mem = np.random.uniform(0.0, 255.0, eltCount).astype(np.dtype('f4'))
+
+ d_mem = cuda.mem_alloc(eltCount * 4)
+ cuda.memcpy_htod(d_mem, h_mem)
+ buf.insert(binding_idx, int(d_mem))
+ mem.append(d_mem)
+ return inp_idx
+
+
+#Run inference on device
+def time_inference(engine, batch_size, inp):
+ bindings = []
+ mem = []
+ inp_idx = 0
+ for io in get_input_output_names(engine):
+ inp_idx = create_memory(engine, io, bindings, mem,
+ batch_size, inp, inp_idx)
+
+ context = engine.create_execution_context()
+ g_prof = Profiler(500)
+ context.set_profiler(g_prof)
+ for i in range(iter):
+ context.execute(batch_size, bindings)
+ g_prof.print_layer_times()
+
+ context.destroy()
+ return
+
+
+def convert_to_datatype(v):
+ if v==8:
+ return trt.infer.DataType.INT8
+ elif v==16:
+ return trt.infer.DataType.HALF
+ elif v==32:
+ return trt.infer.DataType.FLOAT
+ else:
+ print("ERROR: Invalid model data type bit depth: " + str(v))
+ return trt.infer.DataType.INT8
+
+def run_trt_engine(engine_file, bs, it):
+ engine = trt.utils.load_engine(G_LOGGER, engine_file)
+ time_inference(engine, bs, it)
+
+def run_onnx(onnx_file, data_type, bs, inp):
+ # Create onnx_config
+ apex = onnxparser.create_onnxconfig()
+ apex.set_model_file_name(onnx_file)
+ apex.set_model_dtype(convert_to_datatype(data_type))
+
+ # create parser
+ trt_parser = onnxparser.create_onnxparser(apex)
+ assert(trt_parser)
+ data_type = apex.get_model_dtype()
+ onnx_filename = apex.get_model_file_name()
+ trt_parser.parse(onnx_filename, data_type)
+ trt_parser.report_parsing_info()
+ trt_parser.convert_to_trtnetwork()
+ trt_network = trt_parser.get_trtnetwork()
+ assert(trt_network)
+
+ # create infer builder
+ trt_builder = trt.infer.create_infer_builder(G_LOGGER)
+ trt_builder.set_max_batch_size(max_batch_size)
+ trt_builder.set_max_workspace_size(max_workspace_size)
+
+ if (apex.get_model_dtype() == trt.infer.DataType_kHALF):
+ print("------------------- Running FP16 -----------------------------")
+ trt_builder.set_half2_mode(True)
+ elif (apex.get_model_dtype() == trt.infer.DataType_kINT8):
+ print("------------------- Running INT8 -----------------------------")
+ trt_builder.set_int8_mode(True)
+ else:
+ print("------------------- Running FP32 -----------------------------")
+
+ print("----- Builder is Done -----")
+ print("----- Creating Engine -----")
+ trt_engine = trt_builder.build_cuda_engine(trt_network)
+ print("----- Engine is built -----")
+ time_inference(engine, bs, inp)