import tensorflow as tf
import iris_data

tf.logging.set_verbosity(tf.logging.INFO)

def model_fn(features, labels, mode, params):
    """Build the graph for a fully connected classifier and return its spec.

    Args:
        features: Input features, fed to ``tf.feature_column.input_layer``.
        labels: Targets passed to ``tf.losses.sparse_softmax_cross_entropy``
            and ``tf.metrics.accuracy``; not used in PREDICT mode.
        mode: One of the ``tf.estimator.ModeKeys`` values.
        params: Dict read for three keys: ``'feature_columns'``,
            ``'hidden_units'`` (iterable of layer widths) and ``'n_classes'``
            (number of output logits).

    Returns:
        A ``tf.estimator.EstimatorSpec`` carrying predictions (PREDICT),
        loss plus accuracy metric (EVAL), or loss plus train op (TRAIN).
    """
    mode_keys = tf.estimator.ModeKeys

    # Input layer followed by one ReLU dense layer per configured width.
    hidden = tf.feature_column.input_layer(
        features=features,
        feature_columns=params['feature_columns'],
    )
    for layer_width in params['hidden_units']:
        hidden = tf.layers.dense(
            inputs=hidden, units=layer_width, activation=tf.nn.relu)

    # Linear output layer: one logit per class.
    logits = tf.layers.dense(
        inputs=hidden, units=params['n_classes'], activation=None)

    class_probabilities = tf.nn.softmax(logits, 1)
    predictions = {'probabilities': class_probabilities}

    if mode == mode_keys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    # The metric is built here, ahead of the EVAL check, so it is part of
    # the graph in both EVAL and TRAIN modes.
    predicted_classes = tf.argmax(logits, 1)
    accuracy = tf.metrics.accuracy(labels=labels, predictions=predicted_classes)
    eval_metric_ops = {'accuracy': accuracy}

    if mode == mode_keys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=eval_metric_ops)

    # Anything that reaches this point must be a training request.
    assert mode == mode_keys.TRAIN

    optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
    global_step = tf.train.get_global_step()
    train_op = optimizer.minimize(loss=loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

# Load the dataset; each split unpacks into a (features, labels) pair.
train, test = iris_data.load_data()
train_x, train_y = train
test_x, test_y = test

# One numeric feature column for every key of the training features.
feature_columns = [
    tf.feature_column.numeric_column(key) for key in train_x.keys()
]

# Hyper-parameters handed to model_fn through the estimator's `params`.
model_params = {
    "n_classes": 3,
    "feature_columns": feature_columns,
    "hidden_units": [10, 10],
}
classifier = tf.estimator.Estimator(model_fn=model_fn, params=model_params)


def _train_input():
    # Training input pipeline with a batch size of 100.
    return iris_data.train_input_fn(train_x, train_y, 100)


def _eval_input():
    # Evaluation input pipeline over the test split.
    return iris_data.eval_input_fn(test_x, test_y, 100)


classifier.train(input_fn=_train_input, steps=1000)

result = classifier.evaluate(input_fn=_eval_input)
print(result)

# Three hand-written samples to run through the trained model.
predict_x = {
    'SepalLength': [5.1, 5.9, 6.9],
    'SepalWidth': [3.3, 3.0, 3.1],
    'PetalLength': [1.7, 4.2, 5.4],
    'PetalWidth': [0.5, 1.5, 2.1],
}


def _predict_input():
    # No labels are supplied for prediction, hence None.
    return iris_data.eval_input_fn(predict_x, None, 100)


predictions = classifier.predict(input_fn=_predict_input)

# Print each element yielded by predict().
for p in predictions:
    print(p)