Browse code

custom estimator 추가

Nepirity Corp authored on02/05/2018 15:42:48
Showing1 changed files

1 1
new file mode 100644
... ...
@@ -0,0 +1,64 @@
1
+import tensorflow as tf
2
+import iris_data
3
+
4
+tf.logging.set_verbosity(tf.logging.INFO)
5
+
6
+def model_fn(features, labels, mode, params):
7
+    net = tf.feature_column.input_layer(features, params['feature_columns'])
8
+
9
+    for units in params['hidden_units']:
10
+        net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
11
+
12
+    logits = tf.layers.dense(net, units=params['n_classes'], activation=None)
13
+
14
+    predictions = {'probabilities':tf.nn.softmax(logits, 1)}
15
+
16
+    if mode == tf.estimator.ModeKeys.PREDICT:
17
+        return tf.estimator.EstimatorSpec(mode, predictions = predictions)
18
+
19
+    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
20
+
21
+    eval_metric_ops = {'accuracy':tf.metrics.accuracy(labels=labels, predictions=tf.argmax(logits, 1))}
22
+
23
+    if mode == tf.estimator.ModeKeys.EVAL:
24
+        return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops = eval_metric_ops)
25
+
26
+    # Create training op.
27
+    assert mode == tf.estimator.ModeKeys.TRAIN
28
+
29
+    optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
30
+    train_op = optimizer.minimize(loss=loss, global_step = tf.train.get_global_step())
31
+
32
+    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op = train_op)
33
+
34
+train, test = iris_data.load_data()
35
+
36
+train_x, train_y = train
37
+test_x, test_y = test
38
+
39
+feature_columns = []
40
+
41
+for key in train_x.keys():
42
+    feature_columns.append(tf.feature_column.numeric_column(key))
43
+
44
+classifier = tf.estimator.Estimator(
45
+    model_fn = model_fn,
46
+    params = {"n_classes":3, "feature_columns":feature_columns, "hidden_units":[10, 10]}
47
+)
48
+
49
+classifier.train(input_fn = lambda:iris_data.train_input_fn(train_x, train_y, 100), steps=1000)
50
+
51
+result = classifier.evaluate(input_fn = lambda : iris_data.eval_input_fn(test_x, test_y, 100))
52
+print result
53
+
54
+predict_x = {
55
+    'SepalLength': [5.1, 5.9, 6.9],
56
+    'SepalWidth': [3.3, 3.0, 3.1],
57
+    'PetalLength': [1.7, 4.2, 5.4],
58
+    'PetalWidth': [0.5, 1.5, 2.1],
59
+}
60
+
61
+predictions = classifier.predict(input_fn = lambda : iris_data.eval_input_fn(predict_x, None, 100))
62
+
63
+for p in predictions:
64
+    print p