0 | 2 |
new file mode 100644 |
... | ... |
@@ -0,0 +1,23 @@ |
1 |
+import tensorflow as tf |
|
2 |
+ |
|
3 |
+def get_simple_graph_def(): |
|
4 |
+ """Create a simple graph and return its graph_def.""" |
|
5 |
+ g = tf.Graph() |
|
6 |
+ with g.as_default(): |
|
7 |
+ a = tf.placeholder( |
|
8 |
+ dtype=tf.float32, shape=(None, 24, 24, 2), name="input") |
|
9 |
+ e = tf.constant( |
|
10 |
+ [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], |
|
11 |
+ name="weights", |
|
12 |
+ dtype=tf.float32) |
|
13 |
+ conv = tf.nn.conv2d( |
|
14 |
+ input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv") |
|
15 |
+ b = tf.constant( |
|
16 |
+ [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=tf.float32) |
|
17 |
+ t = tf.nn.bias_add(conv, b, name="biasAdd") |
|
18 |
+ relu = tf.nn.relu(t, "relu") |
|
19 |
+ idty = tf.identity(relu, "ID") |
|
20 |
+ v = tf.nn.max_pool( |
|
21 |
+ idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") |
|
22 |
+ tf.squeeze(v, name="output") |
|
23 |
+ return g.as_graph_def() |
... | ... |
@@ -3,27 +3,7 @@ from tensorflow.contrib import tensorrt as trt |
3 | 3 |
|
4 | 4 |
import numpy as np |
5 | 5 |
|
6 |
-def get_simple_graph_def(): |
|
7 |
- """Create a simple graph and return its graph_def.""" |
|
8 |
- g = tf.Graph() |
|
9 |
- with g.as_default(): |
|
10 |
- a = tf.placeholder( |
|
11 |
- dtype=tf.float32, shape=(None, 24, 24, 2), name="input") |
|
12 |
- e = tf.constant( |
|
13 |
- [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], |
|
14 |
- name="weights", |
|
15 |
- dtype=tf.float32) |
|
16 |
- conv = tf.nn.conv2d( |
|
17 |
- input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv") |
|
18 |
- b = tf.constant( |
|
19 |
- [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=tf.float32) |
|
20 |
- t = tf.nn.bias_add(conv, b, name="biasAdd") |
|
21 |
- relu = tf.nn.relu(t, "relu") |
|
22 |
- idty = tf.identity(relu, "ID") |
|
23 |
- v = tf.nn.max_pool( |
|
24 |
- idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") |
|
25 |
- tf.squeeze(v, name="output") |
|
26 |
- return g.as_graph_def() |
|
6 |
+from lib import get_simple_graph_def |
|
27 | 7 |
|
28 | 8 |
def run_graph(gdef, dumm_inp): |
29 | 9 |
"""Run given graphdef once.""" |