import tensorflow as tf
from tensorflow.contrib import tensorrt as trt
import numpy as np


def get_simple_graph_def():
  """Create a simple graph and return its graph_def."""
  g = tf.Graph()
  with g.as_default():
    a = tf.placeholder(
        dtype=tf.float32, shape=(None, 24, 24, 2), name="input")
    e = tf.constant(
        [[[[1., 0.5, 4., 6., 0.5, 1.],
           [1., 0.5, 1., 1., 0.5, 1.]]]],
        name="weights",
        dtype=tf.float32)
    conv = tf.nn.conv2d(
        input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv")
    b = tf.constant(
        [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=tf.float32)
    t = tf.nn.bias_add(conv, b, name="biasAdd")
    relu = tf.nn.relu(t, "relu")
    idty = tf.identity(relu, "ID")
    v = tf.nn.max_pool(
        idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
    tf.squeeze(v, name="output")
  return g.as_graph_def()


def run_graph(gdef, dumm_inp):
  """Run given graphdef once."""
  gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.50)
  tf.reset_default_graph()
  g = tf.Graph()
  with g.as_default():
    inp, out = tf.import_graph_def(
        graph_def=gdef, return_elements=["input", "output"])
    inp = inp.outputs[0]
    out = out.outputs[0]
  with tf.Session(
      config=tf.ConfigProto(gpu_options=gpu_options), graph=g) as sess:
    val = sess.run(out, {inp: dumm_inp})
  return val


inp_dims = (100, 24, 24, 2)
dummy_input = np.random.random_sample(inp_dims)
orig_graph = get_simple_graph_def()  # use a frozen graph for inference

# Convert the frozen graph: supported subgraphs are replaced by TensorRT
# engines running in FP32.
trt_graph = trt.create_inference_graph(
    input_graph_def=orig_graph,
    outputs=["output"],
    max_batch_size=inp_dims[0],
    max_workspace_size_bytes=1 << 25,
    precision_mode="FP32",  # TRT engine precision: "FP32", "FP16" or "INT8"
    minimum_segment_size=2  # minimum number of nodes in an engine
)

o1 = run_graph(orig_graph, dummy_input)
o2 = run_graph(trt_graph, dummy_input)
assert np.array_equal(o1, o2)

# For INT8, create_inference_graph returns a calibration graph that must be
# run on representative data before it can be converted to an inference graph.
int8_calib_gdef = trt.create_inference_graph(
    input_graph_def=orig_graph,
    outputs=["output"],
    max_batch_size=inp_dims[0],
    max_workspace_size_bytes=1 << 25,
    precision_mode="INT8",  # TRT engine precision: "FP32", "FP16" or "INT8"
    minimum_segment_size=2  # minimum number of nodes in an engine
)

# Run the calibration graph to collect INT8 calibration statistics (here the
# dummy batch stands in for a real calibration set), then convert it into an
# INT8 inference graph.
run_graph(int8_calib_gdef, dummy_input)
int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)

o5 = run_graph(int8_graph, dummy_input)
assert np.allclose(o1, o5)