1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Simple MNIST classifier example with JIT XLA and timelines.
16
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import argparse
23import sys
24
25import tensorflow as tf
26
27from tensorflow.examples.tutorials.mnist import input_data
28from tensorflow.python.client import timeline
29
30FLAGS = None
31
32
def main(_):
  """Trains a softmax-regression MNIST model, optionally under XLA JIT.

  Reads MNIST from FLAGS.data_dir, trains for 1000 steps of SGD, writes a
  Chrome-trace timeline ('timeline.ctf.json') for the final training step,
  and prints test-set accuracy.
  """
  # Import data.
  mnist = input_data.read_data_sets(FLAGS.data_dir)

  # Model: a single linear layer (softmax regression). 784 = 28*28 pixels,
  # 10 = digit classes.
  x = tf.placeholder(tf.float32, [None, 784])
  w = tf.Variable(tf.zeros([784, 10]))
  b = tf.Variable(tf.zeros([10]))
  y = tf.matmul(x, w) + b

  # Integer class labels (not one-hot).
  y_ = tf.placeholder(tf.int64, [None])

  # The raw cross-entropy formulation
  #   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), ...))
  # can be numerically unstable, so use the fused, stable op on raw logits.
  cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)
  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

  # Build the evaluation ops up front with the rest of the graph, rather than
  # mutating the graph after training has already run.
  correct_prediction = tf.equal(tf.argmax(y, 1), y_)
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  config = tf.ConfigProto()
  jit_level = 0
  if FLAGS.xla:
    # Turns on XLA JIT compilation.
    jit_level = tf.OptimizerOptions.ON_1
  config.graph_options.optimizer_options.global_jit_level = jit_level

  run_metadata = tf.RunMetadata()
  # Use the session as a context manager so it is closed even if training
  # raises (the original sess.close() was skipped on any exception).
  with tf.Session(config=config) as sess:
    tf.global_variables_initializer().run(session=sess)

    # Train.
    train_loops = 1000
    for i in range(train_loops):
      batch_xs, batch_ys = mnist.train.next_batch(100)

      if i == train_loops - 1:
        # Trace only the last step and export it as JSON viewable with
        # chrome://tracing/.
        sess.run(train_step,
                 feed_dict={x: batch_xs,
                            y_: batch_ys},
                 options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                 run_metadata=run_metadata)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        with open('timeline.ctf.json', 'w') as trace_file:
          trace_file.write(trace.generate_chrome_trace_format())
      else:
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Test trained model.
    print(sess.run(accuracy,
                   feed_dict={x: mnist.test.images,
                              y_: mnist.test.labels}))
94
95
def str2bool(v):
  """Parses a command-line boolean flag value.

  Accepts common true/false spellings case-insensitively; passes real bools
  through unchanged (so a `default=True` still works).

  Args:
    v: A string such as 'true', 'False', 'yes', '0', or an actual bool.

  Returns:
    The parsed boolean.

  Raises:
    argparse.ArgumentTypeError: If `v` is not a recognized boolean spelling.
  """
  if isinstance(v, bool):
    return v
  if v.lower() in ('true', 't', 'yes', 'y', '1'):
    return True
  if v.lower() in ('false', 'f', 'no', 'n', '0'):
    return False
  raise argparse.ArgumentTypeError('Boolean value expected, got %r' % v)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/tensorflow/mnist/input_data',
      help='Directory for storing input data')
  # NOTE: `type=bool` is an argparse trap: bool('False') is True, so the flag
  # could never actually be turned off. str2bool parses the value properly.
  parser.add_argument(
      '--xla', type=str2bool, default=True,
      help='Turn xla via JIT on (true/false)')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
107