import tensorflow as tf
def get_simple_graph_def():
    """Create a simple graph and return its graph_def.

    Graph topology (node names in backticks):

        `input`  float32 placeholder, shape (None, 24, 24, 2)
        -> `conv`      1x1 Conv2D, 2 -> 6 channels, stride 2, SAME  (24 -> 12)
        -> `biasAdd`   add a 6-element constant bias (`bias`)
        -> `relu`
        -> `ID`        identity
        -> `max_pool`  2x2 window, stride 2, VALID                  (12 -> 6)
        -> `output`    squeeze

    Note: ``tf.squeeze`` without ``axis`` drops every size-1 dimension, so
    a fed batch of size 1 also loses its batch dimension at run time.

    The graph is built through ``tf.compat.v1`` when it is available, so
    this works on TF 2.x (where ``tf.placeholder`` and the ``filter=``
    keyword of ``tf.nn.conv2d`` no longer exist) as well as on TF 1.x.

    Returns:
      The ``GraphDef`` protocol buffer of the constructed graph.
    """
    # Pick the TF1-compatible API surface: `tf.compat.v1` exists from
    # TF 1.13 onward and in TF 2.x; older 1.x releases only have the
    # top-level names, which already are the v1 API there.
    tf_v1 = getattr(getattr(tf, "compat", None), "v1", tf)

    g = tf_v1.Graph()
    with g.as_default():
        a = tf_v1.placeholder(
            dtype=tf_v1.float32, shape=(None, 24, 24, 2), name="input")
        # Filter literal has shape (1, 1, 2, 6): a 1x1 kernel mapping the
        # 2 input channels to 6 output channels.
        e = tf_v1.constant(
            [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
            name="weights",
            dtype=tf_v1.float32)
        conv = tf_v1.nn.conv2d(
            input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME",
            name="conv")
        # One bias value per conv output channel (6).
        b = tf_v1.constant(
            [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=tf_v1.float32)
        t = tf_v1.nn.bias_add(conv, b, name="biasAdd")
        relu = tf_v1.nn.relu(t, "relu")
        idty = tf_v1.identity(relu, "ID")
        v = tf_v1.nn.max_pool(
            idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
        tf_v1.squeeze(v, name="output")
    return g.as_graph_def()