1#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3#  Licensed under the Apache License, Version 2.0 (the "License");
4#  you may not use this file except in compliance with the License.
5#  You may obtain a copy of the License at
6#
7#   http://www.apache.org/licenses/LICENSE-2.0
8#
9#  Unless required by applicable law or agreed to in writing, software
10#  distributed under the License is distributed on an "AS IS" BASIS,
11#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12#  See the License for the specific language governing permissions and
13#  limitations under the License.
"""Example of a DNNClassifier for the Iris plant dataset.
15
This example uses APIs in TensorFlow 1.4 or above.
17"""
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import os
24import urllib
25
26import tensorflow as tf
27
# Data sets: local CSV file names and the URLs they are fetched from.
_IRIS_DATA_URL_BASE = 'http://download.tensorflow.org/data/'

IRIS_TRAINING = 'iris_training.csv'
IRIS_TRAINING_URL = _IRIS_DATA_URL_BASE + IRIS_TRAINING

IRIS_TEST = 'iris_test.csv'
IRIS_TEST_URL = _IRIS_DATA_URL_BASE + IRIS_TEST

# Names given to the four feature columns of every CSV row (the label is the
# column that follows them).
FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
36
37
def maybe_download_iris_data(file_name, download_url):
  """Downloads `download_url` to `file_name` unless it already exists.

  Args:
    file_name: Local path of the CSV file.
    download_url: URL fetched only when `file_name` does not exist yet.

  Returns:
    The integer in the first comma-separated field of the file's first
    (header) line, which this example uses as the number of data rows.
  """
  if not os.path.exists(file_name):
    # `urllib.urlopen` exists only on Python 2; Python 3 moved it to
    # `urllib.request`. Import whichever one is available.
    try:
      from urllib.request import urlopen  # Python 3
    except ImportError:
      from urllib import urlopen  # Python 2
    from contextlib import closing
    # `closing` releases the connection even on Python 2, whose response
    # objects are not context managers.
    with closing(urlopen(download_url)) as response:
      raw = response.read()
    # The payload is bytes on Python 3, so write in binary mode. The whole
    # body is read before the file is opened, so a failed download does not
    # leave a truncated file behind.
    with open(file_name, 'wb') as f:
      f.write(raw)

  # The first line is a comma-separated header whose first field is the
  # number of data rows in the file.
  with open(file_name, 'r') as f:
    first_line = f.readline()
  num_elements = first_line.split(',')[0]
  return int(num_elements)
51
52
def input_fn(file_name, num_data, batch_size, is_training):
  """Builds the zero-argument input function consumed by Estimator.

  Args:
    file_name: Path of a local CSV file. Its first line is skipped; each
      following row holds one value per FEATURE_KEYS entry followed by the
      class label.
    num_data: Shuffle buffer size used when training (the caller passes the
      number of data rows).
    batch_size: Number of rows per batch.
    is_training: If True, rows are shuffled and repeated indefinitely;
      otherwise the file is read once, in order.

  Returns:
    A callable that returns a `(features, labels)` pair, where `features`
    maps each name in FEATURE_KEYS to a tensor and `labels` is int32.
  """

  def _parse_csv(line):
    """Splits a CSV line tensor into a (features, labels) tuple."""
    feature_count = len(FEATURE_KEYS)
    # One column per feature plus a trailing label column; an empty default
    # marks each column as required.
    defaults = [[]] * (feature_count + 1)
    fields = tf.decode_csv(line, record_defaults=defaults)
    features = {key: fields[i] for i, key in enumerate(FEATURE_KEYS)}
    label = tf.cast(fields[feature_count], tf.int32)
    return features, label

  def _input_fn():
    """Returns the (features, labels) tensors for the next batch."""
    # Drop the header line, then parse every remaining row.
    dataset = tf.data.TextLineDataset([file_name]).skip(1).map(_parse_csv)
    if is_training:
      # The dataset is small enough to fit in memory, so a buffer as large
      # as the data gives a uniform shuffle; repeat() lets training run for
      # any number of steps.
      dataset = dataset.shuffle(num_data).repeat()
    batches = dataset.batch(batch_size)
    features, labels = batches.make_one_shot_iterator().get_next()
    return features, labels

  return _input_fn
88
89
def main(unused_argv):
  """Trains a DNNClassifier on the Iris training set and prints test accuracy.

  Args:
    unused_argv: Command-line arguments passed in by tf.app.run; ignored.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  # Fetch both CSVs if missing and read their row counts from the headers.
  train_size = maybe_download_iris_data(IRIS_TRAINING, IRIS_TRAINING_URL)
  test_size = maybe_download_iris_data(IRIS_TEST, IRIS_TEST_URL)

  # One scalar numeric column per feature; three hidden layers of 10, 20 and
  # 10 units feeding a 3-class output.
  columns = [tf.feature_column.numeric_column(name, shape=1)
             for name in FEATURE_KEYS]
  estimator = tf.estimator.DNNClassifier(
      feature_columns=columns,
      hidden_units=[10, 20, 10],
      n_classes=3)

  # Train for a fixed number of steps (the training input repeats forever).
  estimator.train(
      input_fn=input_fn(IRIS_TRAINING, train_size, batch_size=32,
                        is_training=True),
      steps=400)

  # Evaluate over a single pass of the test file.
  metrics = estimator.evaluate(
      input_fn=input_fn(IRIS_TEST, test_size, batch_size=32,
                        is_training=False))
  print('Accuracy (tensorflow): {0:f}'.format(metrics['accuracy']))
113
114
if __name__ == '__main__':
  # Hand control to tf.app.run, naming the entry point explicitly.
  tf.app.run(main=main)
117