| 1 | 1 |
new file mode 100644 |
| ... | ... |
@@ -0,0 +1,64 @@ |
| 1 |
+import tensorflow as tf |
|
| 2 |
+import iris_data |
|
| 3 |
+ |
|
| 4 |
+tf.logging.set_verbosity(tf.logging.INFO) |
|
| 5 |
+ |
|
| 6 |
+def model_fn(features, labels, mode, params): |
|
| 7 |
+ net = tf.feature_column.input_layer(features, params['feature_columns']) |
|
| 8 |
+ |
|
| 9 |
+ for units in params['hidden_units']: |
|
| 10 |
+ net = tf.layers.dense(net, units=units, activation=tf.nn.relu) |
|
| 11 |
+ |
|
| 12 |
+ logits = tf.layers.dense(net, units=params['n_classes'], activation=None) |
|
| 13 |
+ |
|
| 14 |
+ predictions = {'probabilities':tf.nn.softmax(logits, 1)}
|
|
| 15 |
+ |
|
| 16 |
+ if mode == tf.estimator.ModeKeys.PREDICT: |
|
| 17 |
+ return tf.estimator.EstimatorSpec(mode, predictions = predictions) |
|
| 18 |
+ |
|
| 19 |
+ loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) |
|
| 20 |
+ |
|
| 21 |
+ eval_metric_ops = {'accuracy':tf.metrics.accuracy(labels=labels, predictions=tf.argmax(logits, 1))}
|
|
| 22 |
+ |
|
| 23 |
+ if mode == tf.estimator.ModeKeys.EVAL: |
|
| 24 |
+ return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops = eval_metric_ops) |
|
| 25 |
+ |
|
| 26 |
+ # Create training op. |
|
| 27 |
+ assert mode == tf.estimator.ModeKeys.TRAIN |
|
| 28 |
+ |
|
| 29 |
+ optimizer = tf.train.AdagradOptimizer(learning_rate=0.1) |
|
| 30 |
+ train_op = optimizer.minimize(loss=loss, global_step = tf.train.get_global_step()) |
|
| 31 |
+ |
|
| 32 |
+ return tf.estimator.EstimatorSpec(mode, loss=loss, train_op = train_op) |
|
| 33 |
+ |
|
| 34 |
+train, test = iris_data.load_data() |
|
| 35 |
+ |
|
| 36 |
+train_x, train_y = train |
|
| 37 |
+test_x, test_y = test |
|
| 38 |
+ |
|
| 39 |
+feature_columns = [] |
|
| 40 |
+ |
|
| 41 |
+for key in train_x.keys(): |
|
| 42 |
+ feature_columns.append(tf.feature_column.numeric_column(key)) |
|
| 43 |
+ |
|
| 44 |
+classifier = tf.estimator.Estimator( |
|
| 45 |
+ model_fn = model_fn, |
|
| 46 |
+ params = {"n_classes":3, "feature_columns":feature_columns, "hidden_units":[10, 10]}
|
|
| 47 |
+) |
|
| 48 |
+ |
|
| 49 |
+classifier.train(input_fn = lambda:iris_data.train_input_fn(train_x, train_y, 100), steps=1000) |
|
| 50 |
+ |
|
| 51 |
+result = classifier.evaluate(input_fn = lambda : iris_data.eval_input_fn(test_x, test_y, 100)) |
|
| 52 |
+print result |
|
| 53 |
+ |
|
| 54 |
+predict_x = {
|
|
| 55 |
+ 'SepalLength': [5.1, 5.9, 6.9], |
|
| 56 |
+ 'SepalWidth': [3.3, 3.0, 3.1], |
|
| 57 |
+ 'PetalLength': [1.7, 4.2, 5.4], |
|
| 58 |
+ 'PetalWidth': [0.5, 1.5, 2.1], |
|
| 59 |
+} |
|
| 60 |
+ |
|
| 61 |
+predictions = classifier.predict(input_fn = lambda : iris_data.eval_input_fn(predict_x, None, 100)) |
|
| 62 |
+ |
|
| 63 |
+for p in predictions: |
|
| 64 |
+ print p |