1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Example of debugging TensorFlow runtime errors using tfdbg."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import argparse
21import sys
22
23import numpy as np
24import tensorflow as tf
25
26from tensorflow.python import debug as tf_debug
27
28
def main(_):
  """Build a tiny graph and run it so that it fails in the requested way.

  Reads the module-level ``FLAGS`` populated by the ``__main__`` block:

    * ``FLAGS.error`` selects the run:
      - ``"shape_mismatch"`` feeds a (3, 1) array. Its transpose is (1, 3),
        which cannot be matrix-multiplied by the (2, 3) constant ``m``.
      - ``"uninitialized_variable"`` fetches ``z`` without ever running the
        initializer of ``v``.
      - ``"no_error"`` feeds a (1, 3) array, so ``y`` evaluates cleanly.
    * ``FLAGS.debug``: if true, the session is wrapped in tfdbg's
      ``LocalCLIDebugWrapperSession``.
    * ``FLAGS.ui_type``: the UI type handed to that wrapper.

  Args:
    _: Unused (argv forwarded by ``tf.app.run``).

  Raises:
    ValueError: If ``FLAGS.error`` is not one of the recognized values.
  """
  # Validate first so an unknown value fails before any graph or session is
  # created.
  if FLAGS.error not in ("shape_mismatch", "uninitialized_variable",
                         "no_error"):
    raise ValueError("Unrecognized error type: " + FLAGS.error)

  # Construct the TensorFlow network.
  ph_float = tf.placeholder(tf.float32, name="ph_float")
  x = tf.transpose(ph_float, name="x")
  v = tf.Variable(np.array([[-2.0], [-3.0], [6.0]], dtype=np.float32), name="v")
  m = tf.constant(
      np.array([[0.0, 1.0, 2.0], [-4.0, -1.0, 0.0]]),
      dtype=tf.float32,
      name="m")
  y = tf.matmul(m, x, name="y")
  z = tf.matmul(m, v, name="z")

  # Keep a handle on the raw session so it is closed even when the run raises
  # (two of the three modes do so on purpose) and when the debugger wrapper
  # replaces `sess`.
  raw_sess = tf.Session()
  sess = raw_sess
  try:
    if FLAGS.debug:
      sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type)

    if FLAGS.error == "shape_mismatch":
      # (3, 1) input -> x is (1, 3); m (2, 3) x (1, 3) is an invalid matmul.
      feed = {ph_float: np.array([[0.0], [1.0], [2.0]])}
      print(sess.run(y, feed_dict=feed))
    elif FLAGS.error == "uninitialized_variable":
      # No initializer for v is ever run, so fetching z cannot succeed.
      print(sess.run(z))
    else:  # "no_error"
      # (1, 3) input -> x is (3, 1); m (2, 3) x (3, 1) yields a (2, 1) result.
      feed = {ph_float: np.array([[0.0, 1.0, 2.0]])}
      print(sess.run(y, feed_dict=feed))
  finally:
    raw_sess.close()
54
55
if __name__ == "__main__":

  def _str_to_bool(value):
    """Convert a command-line string to a bool.

    Accepts true/false, 1/0, yes/no, y/n and on/off (case-insensitive).
    Anything else is rejected instead of silently being treated as False.

    Raises:
      argparse.ArgumentTypeError: If ``value`` is not a recognized spelling.
    """
    lowered = value.lower()
    if lowered in ("true", "1", "yes", "y", "on"):
      return True
    if lowered in ("false", "0", "no", "n", "off"):
      return False
    raise argparse.ArgumentTypeError(
        "Expected a boolean value, got %r" % value)

  parser = argparse.ArgumentParser()
  parser.register("type", "bool", _str_to_bool)
  parser.add_argument(
      "--error",
      type=str,
      default="shape_mismatch",
      # Reject unknown error types at parse time with a usage message.
      choices=["shape_mismatch", "uninitialized_variable", "no_error"],
      help="Type of the error to generate (shape_mismatch | "
      "uninitialized_variable | no_error).")
  parser.add_argument(
      "--ui_type",
      type=str,
      default="curses",
      help="Command-line user interface type (curses | readline)")
  parser.add_argument(
      "--debug",
      type="bool",
      nargs="?",
      const=True,  # A bare `--debug` (no value) enables the debugger.
      default=False,
      help="Use debugger to track down the runtime error in the example graph")
  # Unrecognized arguments are forwarded to tf.app.run rather than rejected.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
81