Browse code

함수 분리

Ikseon Kang authored on28/06/2019 15:03:29
Showing3 changed files

1 1
new file mode 100644
... ...
@@ -0,0 +1 @@
1
+*.swp
0 2
new file mode 100644
... ...
@@ -0,0 +1,23 @@
1
+import tensorflow as tf
2
+
3
+def get_simple_graph_def():
4
+  """Create a simple graph and return its graph_def."""
5
+  g = tf.Graph()
6
+  with g.as_default():
7
+    a = tf.placeholder(
8
+        dtype=tf.float32, shape=(None, 24, 24, 2), name="input")
9
+    e = tf.constant(
10
+        [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
11
+        name="weights",
12
+        dtype=tf.float32)
13
+    conv = tf.nn.conv2d(
14
+        input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv")
15
+    b = tf.constant(
16
+        [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=tf.float32)
17
+    t = tf.nn.bias_add(conv, b, name="biasAdd")
18
+    relu = tf.nn.relu(t, "relu")
19
+    idty = tf.identity(relu, "ID")
20
+    v = tf.nn.max_pool(
21
+        idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
22
+    tf.squeeze(v, name="output")
23
+  return g.as_graph_def()
... ...
@@ -3,27 +3,7 @@ from tensorflow.contrib import tensorrt as trt
3 3
 
4 4
 import numpy as np
5 5
 
6
-def get_simple_graph_def():
7
-  """Create a simple graph and return its graph_def."""
8
-  g = tf.Graph()
9
-  with g.as_default():
10
-    a = tf.placeholder(
11
-        dtype=tf.float32, shape=(None, 24, 24, 2), name="input")
12
-    e = tf.constant(
13
-        [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
14
-        name="weights",
15
-        dtype=tf.float32)
16
-    conv = tf.nn.conv2d(
17
-        input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv")
18
-    b = tf.constant(
19
-        [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=tf.float32)
20
-    t = tf.nn.bias_add(conv, b, name="biasAdd")
21
-    relu = tf.nn.relu(t, "relu")
22
-    idty = tf.identity(relu, "ID")
23
-    v = tf.nn.max_pool(
24
-        idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
25
-    tf.squeeze(v, name="output")
26
-  return g.as_graph_def()
6
+from lib import get_simple_graph_def
27 7
 
28 8
 def run_graph(gdef, dumm_inp):
29 9
   """Run given graphdef once."""