1 | 1 |
new file mode 100644 |
... | ... |
@@ -0,0 +1,75 @@ |
1 |
+import tensorflow as tf |
|
2 |
+from tensorflow.contrib import tensorrt as trt |
|
3 |
+ |
|
4 |
+import numpy as np |
|
5 |
+ |
|
6 |
def get_simple_graph_def():
  """Build a tiny conv/bias/relu/pool graph and return its GraphDef.

  The graph consumes a float32 placeholder named "input" of shape
  (None, 24, 24, 2) and terminates in a squeeze op named "output".
  """
  graph = tf.Graph()
  with graph.as_default():
    inp = tf.placeholder(
        dtype=tf.float32, shape=(None, 24, 24, 2), name="input")
    # 1x2 kernel mapping the 2 input channels to 6 output channels.
    kernel = tf.constant(
        [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
        name="weights",
        dtype=tf.float32)
    conv = tf.nn.conv2d(
        input=inp,
        filter=kernel,
        strides=[1, 2, 2, 1],
        padding="SAME",
        name="conv")
    bias = tf.constant(
        [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=tf.float32)
    biased = tf.nn.bias_add(conv, bias, name="biasAdd")
    activated = tf.nn.relu(biased, "relu")
    passthrough = tf.identity(activated, "ID")
    pooled = tf.nn.max_pool(
        passthrough, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
    tf.squeeze(pooled, name="output")
  return graph.as_graph_def()
|
27 |
+ |
|
28 |
def run_graph(gdef, dumm_inp):
  """Import *gdef* into a fresh graph, run it once on *dumm_inp*.

  Returns the evaluated value of the graph's "output" tensor.
  """
  # Cap per-process GPU memory so TF and the TRT engine can coexist.
  gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.50)
  tf.reset_default_graph()
  graph = tf.Graph()
  with graph.as_default():
    input_op, output_op = tf.graph_util.import_graph_def(
        graph_def=gdef, return_elements=["input", "output"])
    input_tensor = input_op.outputs[0]
    output_tensor = output_op.outputs[0]
  session_config = tf.ConfigProto(gpu_options=gpu_options)
  with tf.Session(config=session_config, graph=graph) as sess:
    return sess.run(output_tensor, {input_tensor: dumm_inp})
|
42 |
+ |
|
43 |
# --- TF-TRT smoke test: FP32 and INT8 conversions of a simple graph. ---
inp_dims = (100, 24, 24, 2)
dummy_input = np.random.random_sample(inp_dims)

orig_graph = get_simple_graph_def()  # use a frozen graph for inference

# Convert the frozen graph into a TensorRT-optimized graph at FP32.
trt_graph = trt.create_inference_graph(
    input_graph_def=orig_graph,
    outputs=["output"],
    max_batch_size=inp_dims[0],  # engine is built for this max batch size
    max_workspace_size_bytes=1 << 25,
    precision_mode="FP32",  # TRT Engine precision "FP32","FP16" or "INT8"
    minimum_segment_size=2  # minimum number of nodes in an engine
)

o1 = run_graph(orig_graph, dummy_input)
o2 = run_graph(trt_graph, dummy_input)

# FIX: exact float equality (np.array_equal) between the native graph and
# the TRT-optimized graph is brittle — TRT may fuse/reorder float ops.
# Compare with a tolerance, consistent with the INT8 check below.
assert np.allclose(o1, o2)

# Build the INT8 calibration graph with the same conversion settings.
int8_calib_gdef = trt.create_inference_graph(
    input_graph_def=orig_graph,
    outputs=["output"],
    max_batch_size=inp_dims[0],
    max_workspace_size_bytes=1 << 25,
    precision_mode="INT8",  # TRT Engine precision "FP32","FP16" or "INT8"
    minimum_segment_size=2  # minimum number of nodes in an engine
)

# NOTE(review): the calibration graph is run directly and the conversion to
# a real INT8 inference graph is skipped — presumably deliberate for this
# smoke test. The full flow would run calibration data through the graph and
# then convert:
#   int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)
int8_graph = int8_calib_gdef
o5 = run_graph(int8_graph, dummy_input)

assert np.allclose(o1, o5)