16cd8b28da1cd42f9bacbef9d55f64c7f5162bb9cAndrew Harp#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
237606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan#
337606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan#  Licensed under the Apache License, Version 2.0 (the "License");
437606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan#  you may not use this file except in compliance with the License.
537606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan#  You may obtain a copy of the License at
637606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan#
737606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan#   http://www.apache.org/licenses/LICENSE-2.0
837606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan#
937606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan#  Unless required by applicable law or agreed to in writing, software
1037606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan#  distributed under the License is distributed on an "AS IS" BASIS,
1137606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1237606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan#  See the License for the specific language governing permissions and
1337606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan#  limitations under the License.
14d28d4c477b764019b763029145bd81bb491e8a7cA. Unique TensorFlower"""Example of Estimator for Iris plant dataset."""
15d28d4c477b764019b763029145bd81bb491e8a7cA. Unique TensorFlower
16334702e19a920ac21fbbbf5b14f7619cb860c427Martin Wickefrom __future__ import absolute_import
17334702e19a920ac21fbbbf5b14f7619cb860c427Martin Wickefrom __future__ import division
18334702e19a920ac21fbbbf5b14f7619cb860c427Martin Wickefrom __future__ import print_function
19a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlowerimport numpy as np
20c00c073f52c2fc7b6672022c75d0b2abb9d9af3aA. Unique TensorFlowerfrom sklearn import datasets
21c00c073f52c2fc7b6672022c75d0b2abb9d9af3aA. Unique TensorFlowerfrom sklearn import metrics
22a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlowerfrom sklearn import model_selection
23a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlowerimport tensorflow as tf
24e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney
2537606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan
X_FEATURE = 'x'  # Key of the input feature matrix in the features dict fed to the model_fn.
2737606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan
28d28d4c477b764019b763029145bd81bb491e8a7cA. Unique TensorFlower
def my_model(features, labels, mode):
  """DNN with three hidden layers, and dropout of 0.1 probability.

  Model function for `tf.estimator.Estimator`.

  Args:
    features: dict of input tensors; the input matrix is read from
      `features[X_FEATURE]`.
    labels: integer class labels. Used to compute the loss and metrics in
      TRAIN and EVAL modes; not used in PREDICT mode.
    mode: a `tf.estimator.ModeKeys` value (TRAIN, EVAL or PREDICT).

  Returns:
    A `tf.estimator.EstimatorSpec` for the requested mode: predictions for
    PREDICT, loss and train_op for TRAIN, loss and accuracy metric for EVAL.
  """
  # Dropout must only be active while training. `tf.layers.dropout` defaults
  # to `training=False`, in which case it returns its input unchanged.
  is_training = mode == tf.estimator.ModeKeys.TRAIN

  # Create three fully connected layers of size 10, 20 and 10 respectively,
  # each followed by dropout with rate 0.1.
  net = features[X_FEATURE]
  for units in [10, 20, 10]:
    net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
    net = tf.layers.dropout(net, rate=0.1, training=is_training)

  # Compute logits (1 per class; 3 classes).
  logits = tf.layers.dense(net, 3, activation=None)

  # Compute predictions.
  predicted_classes = tf.argmax(logits, 1)
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'class': predicted_classes,
        'prob': tf.nn.softmax(logits)
    }
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  # Compute loss.
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  # Create training op.
  if is_training:
    optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

  # Compute evaluation metrics.
  eval_metric_ops = {
      'accuracy': tf.metrics.accuracy(
          labels=labels, predictions=predicted_classes)
  }
  return tf.estimator.EstimatorSpec(
      mode, loss=loss, eval_metric_ops=eval_metric_ops)
66a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlower
67a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlower
def main(unused_argv):
  """Trains the Iris DNN, then reports test accuracy via sklearn and TF.

  Args:
    unused_argv: command-line arguments forwarded by `tf.app.run`; ignored.
  """
  iris = datasets.load_iris()
  # Hold out 20% of the samples for testing; the fixed seed makes the split
  # reproducible across runs.
  x_train, x_test, y_train, y_test = model_selection.train_test_split(
      iris.data, iris.target, test_size=0.2, random_state=42)

  classifier = tf.estimator.Estimator(model_fn=my_model)

  # Train. With num_epochs=None the input repeats indefinitely, so `steps`
  # is what bounds training.
  train_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
  classifier.train(input_fn=train_input_fn, steps=1000)

  # Predict. A single unshuffled pass keeps predictions aligned with y_test.
  test_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
  predictions = classifier.predict(input_fn=test_input_fn)
  y_predicted = np.array([p['class'] for p in predictions])
  y_predicted = y_predicted.reshape(np.array(y_test).shape)

  # Score with sklearn.
  score = metrics.accuracy_score(y_test, y_predicted)
  print('Accuracy (sklearn): {0:f}'.format(score))

  # Score with tensorflow.
  scores = classifier.evaluate(input_fn=test_input_fn)
  print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
94a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlower
9537606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan
if __name__ == '__main__':
  # Hand control to TensorFlow's app runner, naming the entry point explicitly.
  tf.app.run(main=main)
98