16cd8b28da1cd42f9bacbef9d55f64c7f5162bb9cAndrew Harp# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 237606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan# 337606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan# Licensed under the Apache License, Version 2.0 (the "License"); 437606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan# you may not use this file except in compliance with the License. 537606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan# You may obtain a copy of the License at 637606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan# 737606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan# http://www.apache.org/licenses/LICENSE-2.0 837606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan# 937606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan# Unless required by applicable law or agreed to in writing, software 1037606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan# distributed under the License is distributed on an "AS IS" BASIS, 1137606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1237606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan# See the License for the specific language governing permissions and 1337606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan# limitations under the License. 14d28d4c477b764019b763029145bd81bb491e8a7cA. Unique TensorFlower"""Example of Estimator for Iris plant dataset.""" 15d28d4c477b764019b763029145bd81bb491e8a7cA. Unique TensorFlower 16334702e19a920ac21fbbbf5b14f7619cb860c427Martin Wickefrom __future__ import absolute_import 17334702e19a920ac21fbbbf5b14f7619cb860c427Martin Wickefrom __future__ import division 18334702e19a920ac21fbbbf5b14f7619cb860c427Martin Wickefrom __future__ import print_function 19a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlowerimport numpy as np 20c00c073f52c2fc7b6672022c75d0b2abb9d9af3aA. Unique TensorFlowerfrom sklearn import datasets 21c00c073f52c2fc7b6672022c75d0b2abb9d9af3aA. Unique TensorFlowerfrom sklearn import metrics 22a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlowerfrom sklearn import model_selection 23a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlowerimport tensorflow as tf 24e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney 2537606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan 26a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlowerX_FEATURE = 'x' # Name of the input feature. 2737606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan 28d28d4c477b764019b763029145bd81bb491e8a7cA. Unique TensorFlower 29a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlowerdef my_model(features, labels, mode): 30a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower """DNN with three hidden layers, and dropout of 0.1 probability.""" 31d28d4c477b764019b763029145bd81bb491e8a7cA. Unique TensorFlower # Create three fully connected layers respectively of size 10, 20, and 10 with 32d28d4c477b764019b763029145bd81bb491e8a7cA. Unique TensorFlower # each layer having a dropout probability of 0.1. 33a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower net = features[X_FEATURE] 34a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower for units in [10, 20, 10]: 35a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower net = tf.layers.dense(net, units=units, activation=tf.nn.relu) 36a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower net = tf.layers.dropout(net, rate=0.1) 37a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower 38a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower # Compute logits (1 per class). 39a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower logits = tf.layers.dense(net, 3, activation=None) 40a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower 41a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower # Compute predictions. 42a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower predicted_classes = tf.argmax(logits, 1) 43a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower if mode == tf.estimator.ModeKeys.PREDICT: 44a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower predictions = { 45a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower 'class': predicted_classes, 46a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower 'prob': tf.nn.softmax(logits) 47a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower } 48a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower return tf.estimator.EstimatorSpec(mode, predictions=predictions) 49a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower 50a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower # Compute loss. 51f79c39e9c8291787718015318b396bd11ff7ae71Mark Daoust loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) 52a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower 53a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower # Create training op. 54a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower if mode == tf.estimator.ModeKeys.TRAIN: 55a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower optimizer = tf.train.AdagradOptimizer(learning_rate=0.1) 56a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step()) 57a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) 58a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower 59a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower # Compute evaluation metrics. 60a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower eval_metric_ops = { 61a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower 'accuracy': tf.metrics.accuracy( 62a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower labels=labels, predictions=predicted_classes) 63a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower } 64a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower return tf.estimator.EstimatorSpec( 65a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower mode, loss=loss, eval_metric_ops=eval_metric_ops) 66a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlower 67a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlower 68a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlowerdef main(unused_argv): 69a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlower iris = datasets.load_iris() 70a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower x_train, x_test, y_train, y_test = model_selection.train_test_split( 71a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlower iris.data, iris.target, test_size=0.2, random_state=42) 72a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlower 73a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower classifier = tf.estimator.Estimator(model_fn=my_model) 74a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlower 75a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower # Train. 76a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower train_input_fn = tf.estimator.inputs.numpy_input_fn( 77a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True) 78a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower classifier.train(input_fn=train_input_fn, steps=1000) 79a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower 80a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower # Predict. 81a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower test_input_fn = tf.estimator.inputs.numpy_input_fn( 82a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False) 83a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower predictions = classifier.predict(input_fn=test_input_fn) 84a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower y_predicted = np.array(list(p['class'] for p in predictions)) 85a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower y_predicted = y_predicted.reshape(np.array(y_test).shape) 86a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower 87a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower # Score with sklearn. 88a9d231a5e0749d01d1c3004c410756578e7faccaA. Unique TensorFlower score = metrics.accuracy_score(y_test, y_predicted) 89a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower print('Accuracy (sklearn): {0:f}'.format(score)) 90a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower 91a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower # Score with tensorflow. 92a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower scores = classifier.evaluate(input_fn=test_input_fn) 93a132b8330039e7ed326d090cdae35c97561f68b1A. Unique TensorFlower print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy'])) 94a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlower 9537606a4c63364c56a0834d281023b62d2bda6cd8Vijay Vasudevan 96a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlowerif __name__ == '__main__': 97a3cdbde19fbaa959e559596a555c054b78779ee5A. Unique TensorFlower tf.app.run() 98