import tensorflow as tf
import numpy as np
import iris_data

# Load the train/test splits from the project-local iris_data helper.
# NOTE(review): train_x/test_x look DataFrame-like (they expose .keys() and
# are later passed to dict()); confirm against iris_data.load_data.
(train_x, train_y), (test_x, test_y) = iris_data.load_data()

print(np.shape(train_x))
print(np.shape(train_y))

# One numeric feature column per column of the training features.
feature_columns = [
    tf.feature_column.numeric_column(key) for key in train_x.keys()
]

# Feed-forward classifier: two hidden layers of 10 units each, 3 output classes.
classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
)

# Examples per batch for every input pipeline, and number of training steps.
batch_size = 100
steps = 1000

def train_input_fn(x, y, batch_size, buffer_size=1000):
    """Build the training input pipeline for the estimator.

    Args:
        x: Features; converted with ``dict(x)`` so each key becomes a named
            feature (presumably a pandas DataFrame -- confirm with caller).
        y: Labels aligned element-for-element with ``x``.
        batch_size: Number of examples per batch.
        buffer_size: Size of the shuffle buffer. Defaults to 1000, the value
            that was previously hard-coded.

    Returns:
        A ``tf.data.Dataset`` of ``(features_dict, labels)`` batches that
        repeats indefinitely, so the caller must bound training (e.g. with
        ``steps``).
    """
    dataset = tf.data.Dataset.from_tensor_slices((dict(x), y))
    # Shuffle before repeat() so each pass over the data is shuffled as a
    # whole and every example appears once per epoch.
    dataset = dataset.shuffle(buffer_size).repeat().batch(batch_size)
    return dataset

# train_input_fn repeats its dataset indefinitely, so `steps` is what
# actually bounds training here.
classifier.train(
    input_fn=lambda: train_input_fn(train_x, train_y, batch_size),
    steps=steps,
)

def test_input_fn(x, y, batch_size):
    """Build an evaluation / prediction input pipeline.

    The features ``x`` are converted with ``dict(x)``. When ``y`` is None
    (prediction) the dataset yields feature dicts only; otherwise it yields
    ``(features, labels)`` pairs. Examples are batched in their original
    order with no shuffling or repetition, so the inputs are covered once.
    """
    features = dict(x)
    slices = features if y is None else (features, y)
    return tf.data.Dataset.from_tensor_slices(slices).batch(batch_size)

# Evaluate on the held-out test split. Use the file's own test_input_fn (the
# same pipeline used for prediction below) instead of iris_data.eval_input_fn,
# so evaluation and prediction batch their inputs identically.
result = classifier.evaluate(
    input_fn=lambda: test_input_fn(test_x, test_y, batch_size)
)
print(result)

# Three unlabelled examples to classify.
predict_x = {
    'SepalLength': [5.1, 5.9, 6.9],
    'SepalWidth': [3.3, 3.0, 3.1],
    'PetalLength': [1.7, 4.2, 5.4],
    'PetalWidth': [0.5, 1.5, 2.1],
}

# y=None makes test_input_fn yield feature dicts only (no labels).
predictions = classifier.predict(
    input_fn=lambda: test_input_fn(predict_x, y=None, batch_size=batch_size)
)

# predict() yields one result per input example; print each as it arrives.
for p in predictions:
    print(p)