Browse code

실행 파일 추가 (Add executable file)

Nepirity Corp authored on 17/04/2018 15:34:56
Showing 1 changed file
1 1
new file mode 100644
... ...
@@ -0,0 +1,93 @@
1
+import pandas as pd
2
+import tensorflow as tf
3
+
4
TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

# Column order of the iris CSV files; 'Species' (the last column) is the
# label that load_data and _parse_line split off from the features.
CSV_COLUMN_NAMES = [
    'SepalLength',
    'SepalWidth',
    'PetalLength',
    'PetalWidth',
    'Species',
]

# Human-readable label names; presumably indexed by the integer class id
# stored in the 'Species' column -- confirm against the data files.
SPECIES = ['Setosa', 'Versicolor', 'Virginica']
10
+
11
def maybe_download():
    """Fetch the iris training and test CSV files.

    Each file is retrieved with ``tf.keras.utils.get_file``. Files are
    presumably served from the local Keras cache when already present,
    which is what the "maybe" in the name suggests; verify against the
    installed TF version.

    Returns:
        A ``(train_path, test_path)`` tuple of local file paths.
    """
    def _fetch(url):
        # Store the file under the last path component of its URL.
        return tf.keras.utils.get_file(url.rsplit('/', 1)[-1], url)

    train_path = _fetch(TRAIN_URL)
    test_path = _fetch(TEST_URL)
    return train_path, test_path
16
+
17
def load_data(y_name='Species'):
    """Return the iris dataset as ``(train_x, train_y), (test_x, test_y)``.

    Args:
        y_name: Name of the column to split off as the label. It is removed
            from the returned feature frames.

    Returns:
        Two ``(features, labels)`` pairs, training first and then test. The
        features are DataFrames holding every CSV_COLUMN_NAMES column except
        ``y_name``; the labels are the popped ``y_name`` column.
    """
    def _split(path):
        # header=0 together with names= replaces the file's own first row
        # with CSV_COLUMN_NAMES as the column labels.
        frame = pd.read_csv(path, names=CSV_COLUMN_NAMES, header=0)
        # pop() removes the label column in place, so ``frame`` keeps only
        # the feature columns.
        label = frame.pop(y_name)
        return frame, label

    train_path, test_path = maybe_download()
    return _split(train_path), _split(test_path)
28
+
29
+
30
def train_input_fn(features, labels, batch_size):
    """Build the input pipeline used for training.

    Each example's feature dictionary is paired with its label. The examples
    are then shuffled through a 1000-element buffer, repeated indefinitely
    and grouped into batches.

    Args:
        features: Mapping of feature name -> column values (e.g. a pandas
            DataFrame); converted with ``dict()``.
        labels: Label values aligned row-for-row with ``features``.
        batch_size: Number of examples per batch.

    Returns:
        A shuffled, endlessly repeating, batched ``tf.data.Dataset``.
    """
    examples = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    shuffled = examples.shuffle(1000)
    # repeat() without a count never ends; the caller presumably bounds
    # training, e.g. by a step count -- confirm at the call site.
    return shuffled.repeat().batch(batch_size)
40
+
41
+
42
def eval_input_fn(features, labels, batch_size):
    """An input function for evaluation or prediction.

    Args:
        features: Mapping of feature name -> column values (e.g. a pandas
            DataFrame); converted with ``dict()``.
        labels: Label values aligned with ``features``, or ``None`` for
            prediction. With ``None`` the dataset yields feature
            dictionaries only.
        batch_size: Number of examples per batch; must not be ``None``.

    Returns:
        A batched ``tf.data.Dataset``. It is neither shuffled nor repeated,
        so it makes a single pass over the input.

    Raises:
        ValueError: If ``batch_size`` is ``None``.
    """
    # Validate with a real exception, and do it before any work. An
    # ``assert`` is stripped when Python runs with -O, which would let None
    # reach ``batch()``.
    if batch_size is None:
        raise ValueError("batch_size must not be None")

    features = dict(features)
    # Prediction has no labels, so only the features are sliced.
    inputs = features if labels is None else (features, labels)

    # Slice the inputs into one element per example, then batch them.
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    return dataset.batch(batch_size)
60
+
61
+
62
# The remainder of this file contains a simple example of a CSV parser,
#     implemented using the `Dataset` class.
64
+
65
# `tf.decode_csv` sets the types of the outputs to match the examples given in
#     the `record_defaults` argument.
67
# Four float feature columns followed by the integer 'Species' label column.
CSV_TYPES = [[0.0] for _ in range(4)] + [[0]]
68
+
69
def _parse_line(line):
    """Parse one CSV text line into a ``(features, label)`` pair.

    Args:
        line: One record of comma-separated text, as yielded by the
            ``tf.data.TextLineDataset`` in csv_input_fn.

    Returns:
        A dict mapping each feature column name to its parsed tensor, and
        the parsed 'Species' tensor as the label.
    """
    # CSV_TYPES fixes each output column's type: float features and an
    # integer label.
    columns = tf.decode_csv(line, record_defaults=CSV_TYPES)
    parsed = {name: value for name, value in zip(CSV_COLUMN_NAMES, columns)}
    # Remove the label so only the four measurement columns stay as features.
    target = parsed.pop('Species')
    return parsed, target
80
+
81
+
82
def csv_input_fn(csv_path, batch_size):
    """Build a shuffled, endlessly repeating, batched dataset from a CSV file.

    Args:
        csv_path: Path of the CSV file to read.
        batch_size: Number of examples per batch.

    Returns:
        A ``tf.data.Dataset`` of ``(features, label)`` batches. Each line is
        parsed with ``_parse_line``.
    """
    # Skip the first line (presumably the header row) before parsing.
    records = tf.data.TextLineDataset(csv_path).skip(1).map(_parse_line)
    # Same shuffle/repeat/batch pipeline as train_input_fn uses in memory.
    return records.shuffle(1000).repeat().batch(batch_size)