| 1 | 1 |
new file mode 100644 |
| ... | ... |
@@ -0,0 +1,75 @@ |
| 1 |
+import tensorflow as tf |
|
| 2 |
+from tensorflow.contrib import tensorrt as trt |
|
| 3 |
+ |
|
| 4 |
+import numpy as np |
|
| 5 |
+ |
|
def get_simple_graph_def():
    """Create a simple graph and return its graph_def.

    The graph is: placeholder -> conv2d -> bias_add -> relu -> identity
    -> max_pool -> squeeze. Node names ("input", "output", ...) are fixed
    because callers look tensors up by name after re-importing the GraphDef.
    """
    graph = tf.Graph()
    with graph.as_default():
        inp = tf.placeholder(
            dtype=tf.float32, shape=(None, 24, 24, 2), name="input")
        weights = tf.constant(
            [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
            name="weights",
            dtype=tf.float32)
        conv = tf.nn.conv2d(
            input=inp,
            filter=weights,
            strides=[1, 2, 2, 1],
            padding="SAME",
            name="conv")
        bias = tf.constant(
            [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=tf.float32)
        biased = tf.nn.bias_add(conv, bias, name="biasAdd")
        activated = tf.nn.relu(biased, "relu")
        passthrough = tf.identity(activated, "ID")
        pooled = tf.nn.max_pool(
            passthrough, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
        tf.squeeze(pooled, name="output")
    return graph.as_graph_def()
| 27 |
+ |
|
def run_graph(gdef, dumm_inp):
    """Run given graphdef once.

    Imports *gdef* into a fresh graph, feeds *dumm_inp* to the "input"
    tensor, and returns the value of the "output" tensor.
    """
    tf.reset_default_graph()
    graph = tf.Graph()
    with graph.as_default():
        input_op, output_op = tf.graph_util.import_graph_def(
            graph_def=gdef, return_elements=["input", "output"])
        # return_elements yields Operations; grab their first output tensors.
        input_tensor = input_op.outputs[0]
        output_tensor = output_op.outputs[0]
    # Cap GPU memory so consecutive TF/TRT sessions can share one device.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.50)
    session_config = tf.ConfigProto(gpu_options=gpu_options)
    with tf.Session(config=session_config, graph=graph) as sess:
        return sess.run(output_tensor, {input_tensor: dumm_inp})
| 42 |
+ |
|
# Input shape: (batch, height, width, channels); the batch dimension also
# serves as the TRT engine's max batch size below.
inp_dims = (100, 24, 24, 2)
# Cast to float32 so the feed matches the placeholder dtype exactly instead
# of relying on an implicit float64 -> float32 conversion at feed time.
dummy_input = np.random.random_sample(inp_dims).astype(np.float32)

orig_graph = get_simple_graph_def()  # use a frozen graph for inference

# Convert the frozen graph into a TensorRT-optimized graph.
trt_graph = trt.create_inference_graph(
    input_graph_def=orig_graph,
    outputs=["output"],
    max_batch_size=inp_dims[0],
    max_workspace_size_bytes=1 << 25,
    precision_mode="FP32",  # TRT Engine precision "FP32","FP16" or "INT8"
    minimum_segment_size=2  # minimum number of nodes in an engine
)

o1 = run_graph(orig_graph, dummy_input)
o2 = run_graph(trt_graph, dummy_input)

# TensorRT may reorder/fuse floating-point operations, so the two results
# are only expected to agree within tolerance, not bit-for-bit — exact
# comparison (np.array_equal) is flaky here.
assert np.allclose(o1, o2)

int8_calib_gdef = trt.create_inference_graph(
    input_graph_def=orig_graph,
    outputs=["output"],
    max_batch_size=inp_dims[0],
    max_workspace_size_bytes=1 << 25,
    precision_mode="INT8",  # TRT Engine precision "FP32","FP16" or "INT8"
    minimum_segment_size=2  # minimum number of nodes in an engine
)

# NOTE(review): the calibration graph is executed directly; the full INT8
# workflow would feed calibration batches through it and then call
# trt.calib_graph_to_infer_graph — confirm this shortcut is intentional.
int8_graph = int8_calib_gdef
#int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)
o5 = run_graph(int8_graph, dummy_input)

assert np.allclose(o1, o5)