1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Unit tests for the framework of debug-wrapped sessions."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import os
21import shutil
22import tempfile
23import threading
24
25import numpy as np
26
27from tensorflow.core.protobuf import config_pb2
28from tensorflow.core.protobuf import rewriter_config_pb2
29from tensorflow.python.client import session
30from tensorflow.python.debug.lib import debug_data
31from tensorflow.python.debug.wrappers import framework
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import errors
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import test_util
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import math_ops
39# Import resource_variable_ops for the variables-to-tensor implicit conversion.
40from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
41from tensorflow.python.ops import variables
42from tensorflow.python.platform import googletest
43from tensorflow.python.training import monitored_session
44from tensorflow.python.util import tf_inspect
45
46
class TestDebugWrapperSession(framework.BaseDebugWrapperSession):
  """Concrete BaseDebugWrapperSession that records its callback activity.

  Each callback writes into a caller-supplied observer dict, so that a test can
  check how many times the callback fired and what its request carried.
  """

  def __init__(self, sess, dump_root, observer, thread_name_filter=None):
    """Stores the dump root and the observer, then wraps `sess`.

    Args:
      sess: The session object to be wrapped.
      dump_root: (str) Directory path; "file://" + dump_root is the debug URL
        returned from on_run_start.
      observer: (dict) Mutated by the callbacks. The counter keys
        "sess_init_count", "on_run_start_count" and "on_run_end_count" must
        already be present, because they are incremented in place.
      thread_name_filter: Forwarded as-is to the superclass constructor.
    """
    # Set the attributes first: the callbacks below read them, and
    # on_session_init is expected to fire from within the superclass
    # constructor.
    self._dump_root = dump_root
    self._obs = observer

    super(TestDebugWrapperSession, self).__init__(
        sess, thread_name_filter=thread_name_filter)

  def on_session_init(self, request):
    """Counts the invocation, records the session and lets init proceed."""
    record = self._obs
    record["sess_init_count"] += 1
    record["request_sess"] = request.session

    proceed = framework.OnSessionInitAction.PROCEED
    return framework.OnSessionInitResponse(proceed)

  def on_run_start(self, request):
    """Counts the invocation, records fetches/feeds, requests a debug run."""
    record = self._obs
    record["on_run_start_count"] += 1
    record["run_fetches"] = request.fetches
    record["run_feed_dict"] = request.feed_dict

    # Direct the debug dump to the directory supplied at construction time.
    debug_urls = ["file://" + self._dump_root]
    return framework.OnRunStartResponse(
        framework.OnRunStartAction.DEBUG_RUN, debug_urls)

  def on_run_end(self, request):
    """Counts the invocation and records the action taken and any error."""
    record = self._obs
    record["on_run_end_count"] += 1
    record["performed_action"] = request.performed_action
    record["tf_error"] = request.tf_error

    return framework.OnRunEndResponse()
89
90
class TestDebugWrapperSessionBadAction(framework.BaseDebugWrapperSession):
  """BaseDebugWrapperSession subclass that can hand back invalid responses.

  Depending on the constructor arguments, the callbacks place a bad action
  value in OnSessionInitResponse and/or OnRunStartResponse, or bad debug URLs
  in OnRunStartResponse, so the handling of such invalid values can be tested.
  When no bad value is configured, the callbacks return valid responses.
  """

  def __init__(self,
               sess,
               bad_init_action=None,
               bad_run_start_action=None,
               bad_debug_urls=None):
    """Constructor.

    Args:
      sess: The TensorFlow Session object to be wrapped.
      bad_init_action: (str) Bad action value to be returned from the
        on-session-init callback. A falsy value selects PROCEED instead.
      bad_run_start_action: (str) Bad action value to be returned from the
        on-run-start callback. A falsy value selects DEBUG_RUN instead.
      bad_debug_urls: Bad URL values to be returned from the on-run-start
        callback. A falsy value selects an empty list instead.
    """
    self._bad_debug_urls = bad_debug_urls
    self._bad_run_start_action = bad_run_start_action
    self._bad_init_action = bad_init_action

    # The attributes are assigned before the superclass constructor is
    # invoked, so they exist by the time any callback runs.
    super(TestDebugWrapperSessionBadAction, self).__init__(sess)

  def on_session_init(self, request):
    """Returns the configured bad init action, or PROCEED if none is set."""
    action = self._bad_init_action or framework.OnSessionInitAction.PROCEED
    return framework.OnSessionInitResponse(action)

  def on_run_start(self, request):
    """Returns the configured bad action and/or URLs, else a valid response."""
    action = (
        self._bad_run_start_action or framework.OnRunStartAction.DEBUG_RUN)
    return framework.OnRunStartResponse(action, self._bad_debug_urls or [])

  def on_run_end(self, request):
    """Returns a default OnRunEndResponse."""
    return framework.OnRunEndResponse()
142
143
class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
  """Tests of the callbacks and session pass-through of debug wrappers.

  setUp builds a small graph: variables a (2x2) and b (2x1), constant c (2x1),
  placeholder ph, p = matmul(a, b), q = matmul(a, ph) and s = p + c.
  """

  def _no_rewrite_session_config(self):
    """Returns a ConfigProto whose rewriter has model pruning disabled."""
    rewriter_config = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True)
    graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
    return config_pb2.ConfigProto(graph_options=graph_options)

  def setUp(self):
    """Creates the observer, the dump directory, the session and the graph."""
    # Mutated by the TestDebugWrapperSession callbacks; inspected by the tests.
    self._observer = {
        "sess_init_count": 0,
        "request_sess": None,
        "on_run_start_count": 0,
        "run_fetches": None,
        "run_feed_dict": None,
        "on_run_end_count": 0,
        "performed_action": None,
        "tf_error": None,
    }

    self._dump_root = tempfile.mkdtemp()

    self._sess = session.Session(config=self._no_rewrite_session_config())

    self._a_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
    self._b_init_val = np.array([[2.0], [-1.0]])
    self._c_val = np.array([[-4.0], [6.0]])

    self._a_init = constant_op.constant(
        self._a_init_val, shape=[2, 2], name="a_init")
    self._b_init = constant_op.constant(
        self._b_init_val, shape=[2, 1], name="b_init")

    self._ph = array_ops.placeholder(dtype=dtypes.float64, name="ph")

    self._a = variables.Variable(self._a_init, name="a1")
    self._b = variables.Variable(self._b_init, name="b")
    self._c = constant_op.constant(self._c_val, shape=[2, 1], name="c")

    # Matrix product of a and b.
    self._p = math_ops.matmul(self._a, self._b, name="p1")

    # Matrix product of a and ph.
    self._q = math_ops.matmul(self._a, self._ph, name="q")

    # Sum of two vectors.
    self._s = math_ops.add(self._p, self._c, name="s")

    # Initialize the variables.
    self._sess.run(self._a.initializer)
    self._sess.run(self._b.initializer)

  def tearDown(self):
    """Closes the session, removes the dump directory, resets the graph."""
    # Release the session opened in setUp so sessions do not pile up across
    # test methods. One test already closes it through the wrapper; this
    # relies on a second close() being harmless.
    self._sess.close()

    # Tear down temporary dump directory.
    if os.path.isdir(self._dump_root):
      shutil.rmtree(self._dump_root)

    ops.reset_default_graph()

  def _run_in_main_and_child_threads(self, wrapper):
    """Runs a_init in the calling thread and b_init in "ChildThread".

    Both runs go through `wrapper` and both outputs are checked against the
    initial values. The child thread is always joined, even if the run in the
    calling thread raises, so no thread outlives the test.

    Args:
      wrapper: The debug wrapper session through which both runs are issued.
    """
    child_run_output = []

    def child_thread_job():
      child_run_output.append(wrapper.run(self._b_init))

    thread = threading.Thread(name="ChildThread", target=child_thread_job)
    thread.start()
    try:
      main_run_output = wrapper.run(self._a_init)
    finally:
      thread.join()

    self.assertAllClose(self._a_init_val, main_run_output)
    self.assertAllClose([self._b_init_val], child_run_output)

  def testSessionInit(self):
    """Wrapping a session fires on_session_init and mirrors its properties."""
    self.assertEqual(0, self._observer["sess_init_count"])

    wrapper_sess = TestDebugWrapperSession(self._sess, self._dump_root,
                                           self._observer)

    # Assert that on-session-init callback is invoked.
    self.assertEqual(1, self._observer["sess_init_count"])

    # Assert that the request to the on-session-init callback carries the
    # correct session object.
    self.assertEqual(self._sess, self._observer["request_sess"])

    # Verify that the wrapper session implements the session.SessionInterface.
    self.assertIsInstance(wrapper_sess, session.SessionInterface)
    self.assertEqual(self._sess.sess_str, wrapper_sess.sess_str)
    self.assertEqual(self._sess.graph, wrapper_sess.graph)
    self.assertEqual(self._sess.graph_def, wrapper_sess.graph_def)

    # Check that partial_run_setup is not implemented for the debug wrapper
    # session.
    with self.assertRaises(NotImplementedError):
      wrapper_sess.partial_run_setup(self._p)

  def testInteractiveSessionInit(self):
    """The wrapper should work also on other subclasses of session.Session."""
    interactive_sess = session.InteractiveSession()
    try:
      TestDebugWrapperSession(
          interactive_sess, self._dump_root, self._observer)
    finally:
      # Close explicitly so the interactive session does not stay open after
      # this test method returns.
      interactive_sess.close()

  def testSessionRun(self):
    """A run through the wrapper fires both run callbacks and dumps data."""
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer)

    # Check initial state of the observer.
    self.assertEqual(0, self._observer["on_run_start_count"])
    self.assertEqual(0, self._observer["on_run_end_count"])

    s = wrapper.run(self._s)

    # Assert the run return value is correct.
    self.assertAllClose(np.array([[3.0], [4.0]]), s)

    # Assert the on-run-start method is invoked.
    self.assertEqual(1, self._observer["on_run_start_count"])

    # Assert the on-run-start request reflects the correct fetch.
    self.assertEqual(self._s, self._observer["run_fetches"])

    # Assert the on-run-start request reflects the correct feed_dict.
    self.assertIsNone(self._observer["run_feed_dict"])

    # Assert the file debug URL has led to dump on the filesystem.
    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertEqual(7, len(dump.dumped_tensor_data))

    # Assert the on-run-end method is invoked.
    self.assertEqual(1, self._observer["on_run_end_count"])

    # Assert the performed action field in the on-run-end callback request is
    # correct.
    self.assertEqual(
        framework.OnRunStartAction.DEBUG_RUN,
        self._observer["performed_action"])

    # No TensorFlow runtime error should have happened.
    self.assertIsNone(self._observer["tf_error"])

  def testSessionInitInvalidSessionType(self):
    """Attempt to wrap a non-Session-type object should cause an exception."""

    # A wrapper is used here as the non-Session-type object.
    wrapper = TestDebugWrapperSessionBadAction(self._sess)
    with self.assertRaisesRegexp(TypeError, "Expected type .*; got type .*"):
      TestDebugWrapperSessionBadAction(wrapper)

  def testSessionInitBadActionValue(self):
    """An invalid on-session-init action should raise a ValueError."""
    with self.assertRaisesRegexp(
        ValueError, "Invalid OnSessionInitAction value: nonsense_action"):
      TestDebugWrapperSessionBadAction(
          self._sess, bad_init_action="nonsense_action")

  def testRunStartBadActionValue(self):
    """An invalid on-run-start action should raise a ValueError in run()."""
    wrapper = TestDebugWrapperSessionBadAction(
        self._sess, bad_run_start_action="nonsense_action")

    with self.assertRaisesRegexp(
        ValueError, "Invalid OnRunStartAction value: nonsense_action"):
      wrapper.run(self._s)

  def testRunStartBadURLs(self):
    """Debug URLs given as a bare str should raise a TypeError in run()."""
    # debug_urls ought to be a list of str, not a str. So an exception should
    # be raised during a run() call.
    wrapper = TestDebugWrapperSessionBadAction(
        self._sess, bad_debug_urls="file://foo")

    with self.assertRaisesRegexp(TypeError, "Expected type .*; got type .*"):
      wrapper.run(self._s)

  def testErrorDuringRun(self):
    """A runtime error in run() should be reported through on_run_end."""

    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)

    # No matrix size mismatch.
    self.assertAllClose(
        np.array([[11.0], [-1.0]]),
        wrapper.run(self._q, feed_dict={self._ph: np.array([[1.0], [2.0]])}))
    self.assertEqual(1, self._observer["on_run_end_count"])
    self.assertIsNone(self._observer["tf_error"])

    # Now there should be a matrix size mismatch error. The call is not placed
    # under assertRaises: the test relies on the wrapper handing the error to
    # the on-run-end callback rather than raising it from run().
    wrapper.run(self._q, feed_dict={self._ph: np.array([[1.0], [2.0], [3.0]])})
    self.assertEqual(2, self._observer["on_run_end_count"])
    self.assertIsInstance(
        self._observer["tf_error"], errors.InvalidArgumentError)

  def testUsingWrappedSessionShouldWorkAsContextManager(self):
    """Inside `with wrapper`, both eval() and run() go through the wrapper."""
    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)

    with wrapper as sess:
      self.assertAllClose([[3.0], [4.0]], self._s.eval())
      self.assertEqual(1, self._observer["on_run_start_count"])
      self.assertEqual(self._s, self._observer["run_fetches"])
      self.assertEqual(1, self._observer["on_run_end_count"])

      self.assertAllClose(
          [[11.0], [-1.0]],
          sess.run(self._q, feed_dict={self._ph: np.array([[1.0], [2.0]])}))
      self.assertEqual(2, self._observer["on_run_start_count"])
      self.assertEqual(self._q, self._observer["run_fetches"])
      self.assertEqual(2, self._observer["on_run_end_count"])

  def testUsingWrappedSessionShouldSupportEvalWithAsDefault(self):
    """Inside `wrapper.as_default()`, eval() goes through the wrapper."""
    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)

    with wrapper.as_default():
      foo = constant_op.constant(42, name="foo")
      self.assertEqual(42, foo.eval())
      self.assertEqual(foo, self._observer["run_fetches"])

  def testWrapperShouldSupportSessionClose(self):
    """close() should be callable on the wrapper."""
    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)
    wrapper.close()

  def testWrapperThreadNameFilterMainThread(self):
    """With filter "MainThread", only the calling thread's run is dumped."""
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer,
        thread_name_filter="MainThread")

    self._run_in_main_and_child_threads(wrapper)

    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertEqual(1, dump.size)
    self.assertEqual("a_init", dump.dumped_tensor_data[0].node_name)

  def testWrapperThreadNameFilterChildThread(self):
    """With filter "Child.*", only the child thread's run is dumped."""
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer,
        thread_name_filter=r"Child.*")

    self._run_in_main_and_child_threads(wrapper)

    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertEqual(1, dump.size)
    self.assertEqual("b_init", dump.dumped_tensor_data[0].node_name)

  def testWrapperThreadNameFilterBothThreads(self):
    """With no filter, the runs from both threads are dumped."""
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer,
        thread_name_filter=None)

    self._run_in_main_and_child_threads(wrapper)

    # NOTE(review): validate=False is presumably needed because two separate
    # runs dump into the same directory -- confirm.
    dump = debug_data.DebugDumpDir(self._dump_root, validate=False)
    self.assertEqual(2, dump.size)
    self.assertItemsEqual(
        ["a_init", "b_init"],
        [datum.node_name for datum in dump.dumped_tensor_data])
408
409
def _is_public_method_name(method_name):
  """Returns whether `method_name` counts as a public method name.

  A name is public if it has no leading underscore, or if it is a dunder name
  (both starts and ends with a double underscore).

  Args:
    method_name: (str) The method name to classify.

  Returns:
    (bool) True for public and dunder names, False for other underscore-
    prefixed names.
  """
  if not method_name.startswith("_"):
    return True
  # Underscore-prefixed: only dunder names such as "__enter__" qualify.
  return method_name.startswith("__") and method_name.endswith("__")
413
414
class SessionWrapperPublicMethodParityTest(test_util.TensorFlowTestCase):
  """Checks that BaseDebugWrapperSession has the public methods of sessions."""

  def _public_method_names(self, cls):
    """Returns the public names among the members of `cls` passing ismethod.

    NOTE(review): if tf_inspect.ismethod mirrors the standard inspect.ismethod,
    then on Python 3 it does not match plain functions looked up on a class,
    which would leave these lists nearly empty and the checks vacuous --
    confirm.

    Args:
      cls: The class whose members are inspected.

    Returns:
      A list of str member names accepted by _is_public_method_name.
    """
    members = tf_inspect.getmembers(cls, predicate=tf_inspect.ismethod)
    return [name for name, _ in members if _is_public_method_name(name)]

  def _assert_wrapper_has_public_methods_of(self, cls):
    """Asserts that no public method of `cls` is absent from the wrapper."""
    expected_names = self._public_method_names(cls)
    wrapper_names = self._public_method_names(
        framework.BaseDebugWrapperSession)
    absent = [name for name in expected_names if name not in wrapper_names]
    self.assertFalse(absent)

  def testWrapperHasAllPublicMethodsOfSession(self):
    """The wrapper should cover the public methods of session.Session."""
    self._assert_wrapper_has_public_methods_of(session.Session)

  def testWrapperHasAllPublicMethodsOfMonitoredSession(self):
    """The wrapper should cover the public methods of MonitoredSession."""
    self._assert_wrapper_has_public_methods_of(
        monitored_session.MonitoredSession)
447
448
449if __name__ == "__main__":
450  googletest.main()
451