import tensorflow as tf
from tensorflow.contrib import tensorrt as trt

import numpy as np

from lib import get_simple_graph_def

def run_graph(gdef, dumm_inp):
  """Import `gdef` into a fresh graph and evaluate it once.

  Args:
    gdef: GraphDef containing ops named "input" and "output".
    dumm_inp: value fed to the first output tensor of the "input" op.

  Returns:
    The evaluated value of the first output tensor of the "output" op.
  """
  tf.reset_default_graph()
  graph = tf.Graph()
  with graph.as_default():
    # Passing op names (no ":0" suffix) makes import_graph_def return
    # Operations, so the tensors are taken from `.outputs[0]` below.
    input_op, output_op = tf.graph_util.import_graph_def(
        graph_def=gdef, return_elements=["input", "output"])
  input_tensor, output_tensor = input_op.outputs[0], output_op.outputs[0]
  # Let this process allocate at most half of each visible GPU's memory.
  session_config = tf.ConfigProto(
      gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.50))
  with tf.Session(config=session_config, graph=graph) as sess:
    return sess.run(output_tensor, {input_tensor: dumm_inp})

inp_dims = (100, 24, 24, 2)
dummy_input = np.random.random_sample(inp_dims)

orig_graph = get_simple_graph_def()  # use a frozen graph for inference


def _convert_with_trt(precision_mode):
  """Convert `orig_graph` with TF-TRT using the given engine precision.

  All conversion settings except the precision are shared between the
  FP32 and INT8 runs below, so they live in one place.

  Args:
    precision_mode: TRT engine precision, "FP32", "FP16" or "INT8".
      For "INT8" the result is a calibration graph. It still has to be
      run on calibration data and then passed to
      `trt.calib_graph_to_infer_graph` to become an INT8 inference graph.

  Returns:
    The GraphDef returned by `trt.create_inference_graph`.
  """
  return trt.create_inference_graph(
      input_graph_def=orig_graph,
      outputs=["output"],
      max_batch_size=inp_dims[0],  # sized to the dummy input's batch dim
      max_workspace_size_bytes=1 << 25,  # 32 MiB of TRT workspace
      precision_mode=precision_mode,
      minimum_segment_size=2)  # minimum number of nodes in an engine


trt_graph = _convert_with_trt("FP32")

# Reference result from the unmodified TensorFlow graph.
o1 = run_graph(orig_graph, dummy_input)
o2 = run_graph(trt_graph, dummy_input)

# TRT engines may use different kernels / reduction orders than native TF,
# so FP32 results are not guaranteed to be bit-identical; an exact
# np.array_equal check is fragile. Compare with a tolerance instead.
assert np.allclose(o1, o2)

# With "INT8" this is a *calibration* graph, not a final INT8 engine.
# Running it below exercises the calibration path only. Producing the real
# INT8 inference graph would additionally require
# trt.calib_graph_to_infer_graph(int8_calib_gdef); that step is not done
# here.
int8_calib_gdef = _convert_with_trt("INT8")

int8_graph = int8_calib_gdef
o5 = run_graph(int8_graph, dummy_input)

# NOTE(review): this presumably holds because the calibration graph still
# computes in full precision while collecting ranges -- confirm. A true INT8
# graph would likely need a much looser tolerance.
assert np.allclose(o1, o5)