1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A very simple MNIST classifier.
16
17See extensive documentation at
18https://www.tensorflow.org/get_started/mnist/beginners
19"""
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import argparse
25import sys
26
27from tensorflow.examples.tutorials.mnist import input_data
28
29import tensorflow as tf
30
# Parsed command-line flags (the namespace returned by
# argparse.ArgumentParser.parse_known_args). Assigned in the __main__ block
# before tf.app.run() invokes main(); None when the module is only imported.
FLAGS = None
32
33
def main(_):
  """Train and evaluate a single-layer softmax classifier on MNIST.

  Loads the dataset from FLAGS.data_dir and fits logits y = xW + b with
  softmax cross-entropy. Training runs 1000 steps of mini-batch (size 100)
  gradient descent with learning rate 0.5. The accuracy on the test split
  is then printed.

  Args:
    _: Unused; receives the leftover argv forwarded by tf.app.run.
  """
  # Import data. read_data_sets is called without one_hot, and labels are fed
  # into an int64 [None] placeholder below, so they are expected to be integer
  # class ids rather than one-hot vectors.
  mnist = input_data.read_data_sets(FLAGS.data_dir)

  # Create the model: each row of x has 784 features, mapped to 10 logits.
  x = tf.placeholder(tf.float32, [None, 784])
  W = tf.Variable(tf.zeros([784, 10]))
  b = tf.Variable(tf.zeros([10]))
  y = tf.matmul(x, W) + b

  # Define loss and optimizer. y_ holds one integer class id per example.
  y_ = tf.placeholder(tf.int64, [None])

  # The raw formulation of cross-entropy (written here for one-hot labels),
  #
  #   tf.reduce_mean(-tf.reduce_sum(
  #       tf.one_hot(y_, 10) * tf.log(tf.nn.softmax(y)), axis=[1]))
  #
  # can be numerically unstable: tf.log of a softmax entry that underflows
  # to zero yields -inf.
  #
  # So here we use tf.losses.sparse_softmax_cross_entropy. It takes the raw
  # logits 'y' and the integer labels directly, and averages across the batch.
  cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)
  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

  # Evaluation ops: fraction of examples whose arg-max logit equals the label.
  correct_prediction = tf.equal(tf.argmax(y, 1), y_)
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  # A context-managed Session is closed, and its resources are released, once
  # training and evaluation finish, even if an exception is raised.
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Train
    for _ in range(1000):
      batch_xs, batch_ys = mnist.train.next_batch(100)
      sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Test trained model
    print(sess.run(
        accuracy, feed_dict={
            x: mnist.test.images,
            y_: mnist.test.labels
        }))
74
75
if __name__ == '__main__':
  # Parse only the flags this script defines. Anything unrecognised is
  # handed on to tf.app.run, together with the program name.
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument('--data_dir', type=str,
                          default='/tmp/tensorflow/mnist/input_data',
                          help='Directory for storing input data')
  FLAGS, extra_argv = arg_parser.parse_known_args()
  forwarded_argv = sys.argv[:1] + extra_argv
  tf.app.run(main=main, argv=forwarded_argv)
85