1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Framework of debug-wrapped sessions.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import os 21import shutil 22import tempfile 23import threading 24 25import numpy as np 26 27from tensorflow.core.protobuf import config_pb2 28from tensorflow.core.protobuf import rewriter_config_pb2 29from tensorflow.python.client import session 30from tensorflow.python.debug.lib import debug_data 31from tensorflow.python.debug.wrappers import framework 32from tensorflow.python.framework import constant_op 33from tensorflow.python.framework import dtypes 34from tensorflow.python.framework import errors 35from tensorflow.python.framework import ops 36from tensorflow.python.framework import test_util 37from tensorflow.python.ops import array_ops 38from tensorflow.python.ops import math_ops 39# Import resource_variable_ops for the variables-to-tensor implicit conversion. 
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import monitored_session
from tensorflow.python.util import tf_inspect


class TestDebugWrapperSession(framework.BaseDebugWrapperSession):
  """A concrete implementation of BaseDebugWrapperSession for test.

  Every callback records the request it received into the `observer` dict
  supplied by the test, and every run is turned into a DEBUG_RUN that dumps
  into `dump_root`.
  """

  def __init__(self, sess, dump_root, observer, thread_name_filter=None):
    """Constructor.

    Args:
      sess: The TensorFlow Session object to be wrapped.
      dump_root: Directory that DEBUG_RUNs dump into (sent to the base class
        as a `file://` debug URL).
      observer: A dict the callbacks update in place. It must already hold
        integer values under "sess_init_count", "on_run_start_count" and
        "on_run_end_count", which the callbacks increment.
      thread_name_filter: Optional thread-name regex forwarded to the base
        class. Judging by the thread-name-filter tests, run() calls from
        threads whose names do not match it are not dumped.
    """
    # Supply dump root.
    self._dump_root = dump_root

    # Supply observer.
    self._obs = observer

    # Invoke superclass constructor last: it triggers on_session_init(),
    # which needs self._obs to be set already.
    framework.BaseDebugWrapperSession.__init__(
        self, sess, thread_name_filter=thread_name_filter)

  def on_session_init(self, request):
    """Records the session-init request and proceeds."""
    self._obs["sess_init_count"] += 1
    self._obs["request_sess"] = request.session

    return framework.OnSessionInitResponse(
        framework.OnSessionInitAction.PROCEED)

  def on_run_start(self, request):
    """Records the run's fetches/feeds and requests a file-dumping DEBUG_RUN."""
    self._obs["on_run_start_count"] += 1
    self._obs["run_fetches"] = request.fetches
    self._obs["run_feed_dict"] = request.feed_dict

    return framework.OnRunStartResponse(
        framework.OnRunStartAction.DEBUG_RUN,
        ["file://" + self._dump_root])

  def on_run_end(self, request):
    """Records the performed action and any TensorFlow runtime error."""
    self._obs["on_run_end_count"] += 1
    self._obs["performed_action"] = request.performed_action
    self._obs["tf_error"] = request.tf_error

    return framework.OnRunEndResponse()


class TestDebugWrapperSessionBadAction(framework.BaseDebugWrapperSession):
  """A concrete implementation of BaseDebugWrapperSession for test.

  This class intentionally puts a bad action value in OnSessionInitResponse
  and/or in OnRunStartResponse, or bad debug URLs in OnRunStartResponse, to
  test the handling of such invalid cases.
  """

  def __init__(
      self,
      sess,
      bad_init_action=None,
      bad_run_start_action=None,
      bad_debug_urls=None):
    """Constructor.

    Args:
      sess: The TensorFlow Session object to be wrapped.
      bad_init_action: (str) bad action value to be returned during the
        on-session-init callback.
      bad_run_start_action: (str) bad action value to be returned during the
        on-run-start callback.
      bad_debug_urls: Bad URL values to be returned during the on-run-start
        callback.
    """
    self._bad_init_action = bad_init_action
    self._bad_run_start_action = bad_run_start_action
    self._bad_debug_urls = bad_debug_urls

    # Invoke superclass constructor last, because it triggers
    # on_session_init(), which reads self._bad_init_action.
    framework.BaseDebugWrapperSession.__init__(self, sess)

  def on_session_init(self, request):
    """Returns the bad init action if one was given, else PROCEED."""
    if self._bad_init_action:
      return framework.OnSessionInitResponse(self._bad_init_action)
    else:
      return framework.OnSessionInitResponse(
          framework.OnSessionInitAction.PROCEED)

  def on_run_start(self, request):
    """Returns the bad run-start action and/or bad debug URLs, if given."""
    debug_urls = self._bad_debug_urls or []

    if self._bad_run_start_action:
      return framework.OnRunStartResponse(
          self._bad_run_start_action, debug_urls)
    else:
      return framework.OnRunStartResponse(
          framework.OnRunStartAction.DEBUG_RUN, debug_urls)

  def on_run_end(self, request):
    return framework.OnRunEndResponse()


class DebugWrapperSessionTest(test_util.TensorFlowTestCase):

  def _no_rewrite_session_config(self):
    """Returns a ConfigProto with graph model pruning disabled."""
    rewriter_config = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True)
    graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
    return config_pb2.ConfigProto(graph_options=graph_options)

  def setUp(self):
    self._observer = {
        "sess_init_count": 0,
        "request_sess": None,
        "on_run_start_count": 0,
        "run_fetches": None,
        "run_feed_dict": None,
        "on_run_end_count": 0,
        "performed_action": None,
        "tf_error": None,
    }

    self._dump_root = tempfile.mkdtemp()

    self._sess = session.Session(config=self._no_rewrite_session_config())

    self._a_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
    self._b_init_val = np.array([[2.0], [-1.0]])
    self._c_val = np.array([[-4.0], [6.0]])

    self._a_init = constant_op.constant(
        self._a_init_val, shape=[2, 2], name="a_init")
    self._b_init = constant_op.constant(
        self._b_init_val, shape=[2, 1], name="b_init")

    self._ph = array_ops.placeholder(dtype=dtypes.float64, name="ph")

    self._a = variables.Variable(self._a_init, name="a1")
    self._b = variables.Variable(self._b_init, name="b")
    self._c = constant_op.constant(self._c_val, shape=[2, 1], name="c")

    # Matrix product of a and b.
    self._p = math_ops.matmul(self._a, self._b, name="p1")

    # Matrix product of a and ph.
    self._q = math_ops.matmul(self._a, self._ph, name="q")

    # Sum of two vectors.
    self._s = math_ops.add(self._p, self._c, name="s")

    # Initialize the variables.
    self._sess.run(self._a.initializer)
    self._sess.run(self._b.initializer)

  def tearDown(self):
    # Release the session created in setUp().
    # NOTE(review): relies on Session.close() being safe to call twice, since
    # some tests already close it through the wrapper (close() or `with`).
    self._sess.close()

    # Tear down temporary dump directory.
    if os.path.isdir(self._dump_root):
      shutil.rmtree(self._dump_root)

    ops.reset_default_graph()

  def _run_in_main_and_child_threads(self, wrapper):
    """Runs a_init on this thread and b_init on a thread named "ChildThread".

    Both run() calls go through `wrapper` concurrently; both results are
    checked against their expected values.

    Args:
      wrapper: The debug-wrapper session to run through.

    Raises:
      Any exception raised by run() inside the child thread is re-raised here,
      so it is reported instead of showing up only as a missing result.
    """
    child_run_output = []
    child_errors = []

    def child_thread_job():
      try:
        child_run_output.append(wrapper.run(self._b_init))
      except Exception as e:  # pylint: disable=broad-except
        # Re-raised on the main thread below.
        child_errors.append(e)

    thread = threading.Thread(name="ChildThread", target=child_thread_job)
    thread.start()
    try:
      self.assertAllClose(self._a_init_val, wrapper.run(self._a_init))
    finally:
      # Always join, even if the main-thread assertion failed.
      thread.join()

    if child_errors:
      raise child_errors[0]
    self.assertAllClose([self._b_init_val], child_run_output)

  def testSessionInit(self):
    self.assertEqual(0, self._observer["sess_init_count"])

    wrapper_sess = TestDebugWrapperSession(self._sess, self._dump_root,
                                           self._observer)

    # Assert that on-session-init callback is invoked.
    self.assertEqual(1, self._observer["sess_init_count"])

    # Assert that the request to the on-session-init callback carries the
    # correct session object.
    self.assertEqual(self._sess, self._observer["request_sess"])

    # Verify that the wrapper session implements the session.SessionInterface.
    self.assertIsInstance(wrapper_sess, session.SessionInterface)
    self.assertEqual(self._sess.sess_str, wrapper_sess.sess_str)
    self.assertEqual(self._sess.graph, wrapper_sess.graph)
    self.assertEqual(self._sess.graph_def, wrapper_sess.graph_def)

    # Check that the partial_run_setup and partial_run are not implemented for
    # the debug wrapper session.
    with self.assertRaises(NotImplementedError):
      wrapper_sess.partial_run_setup(self._p)
    with self.assertRaises(NotImplementedError):
      wrapper_sess.partial_run(None, self._p)

  def testInteractiveSessionInit(self):
    """The wrapper should work also on other subclasses of session.Session."""
    interactive_sess = session.InteractiveSession()
    try:
      TestDebugWrapperSession(
          interactive_sess, self._dump_root, self._observer)

      self.assertEqual(1, self._observer["sess_init_count"])
      self.assertEqual(interactive_sess, self._observer["request_sess"])
    finally:
      # Unlike self._sess, this session is not released in tearDown().
      interactive_sess.close()

  def testSessionRun(self):
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer)

    # Check initial state of the observer.
    self.assertEqual(0, self._observer["on_run_start_count"])
    self.assertEqual(0, self._observer["on_run_end_count"])

    s = wrapper.run(self._s)

    # Assert the run return value is correct.
    self.assertAllClose(np.array([[3.0], [4.0]]), s)

    # Assert the on-run-start method is invoked.
    self.assertEqual(1, self._observer["on_run_start_count"])

    # Assert the on-run-start request reflects the correct fetch.
    self.assertEqual(self._s, self._observer["run_fetches"])

    # Assert the on-run-start request reflects the correct feed_dict.
    self.assertIsNone(self._observer["run_feed_dict"])

    # Assert the file debug URL has led to dump on the filesystem.
    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertEqual(7, len(dump.dumped_tensor_data))

    # Assert the on-run-end method is invoked.
    self.assertEqual(1, self._observer["on_run_end_count"])

    # Assert the performed action field in the on-run-end callback request is
    # correct.
    self.assertEqual(
        framework.OnRunStartAction.DEBUG_RUN,
        self._observer["performed_action"])

    # No TensorFlow runtime error should have happened.
    self.assertIsNone(self._observer["tf_error"])

  def testSessionInitInvalidSessionType(self):
    """Attempt to wrap a non-Session-type object should cause an exception."""

    wrapper = TestDebugWrapperSessionBadAction(self._sess)
    with self.assertRaisesRegexp(TypeError, "Expected type .*; got type .*"):
      TestDebugWrapperSessionBadAction(wrapper)

  def testSessionInitBadActionValue(self):
    with self.assertRaisesRegexp(
        ValueError, "Invalid OnSessionInitAction value: nonsense_action"):
      TestDebugWrapperSessionBadAction(
          self._sess, bad_init_action="nonsense_action")

  def testRunStartBadActionValue(self):
    wrapper = TestDebugWrapperSessionBadAction(
        self._sess, bad_run_start_action="nonsense_action")

    with self.assertRaisesRegexp(
        ValueError, "Invalid OnRunStartAction value: nonsense_action"):
      wrapper.run(self._s)

  def testRunStartBadURLs(self):
    # debug_urls ought to be a list of str, not a str. So an exception should
    # be raised during a run() call.
    wrapper = TestDebugWrapperSessionBadAction(
        self._sess, bad_debug_urls="file://foo")

    with self.assertRaisesRegexp(TypeError, "Expected type .*; got type .*"):
      wrapper.run(self._s)

  def testErrorDuringRun(self):

    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)

    # No matrix size mismatch.
    self.assertAllClose(
        np.array([[11.0], [-1.0]]),
        wrapper.run(self._q, feed_dict={self._ph: np.array([[1.0], [2.0]])}))
    self.assertEqual(1, self._observer["on_run_end_count"])
    self.assertIsNone(self._observer["tf_error"])

    # Now there should be a matrix size mismatch error. The wrapper reports it
    # through the on-run-end callback rather than raising it here.
    wrapper.run(self._q, feed_dict={self._ph: np.array([[1.0], [2.0], [3.0]])})
    self.assertEqual(2, self._observer["on_run_end_count"])
    self.assertIsInstance(
        self._observer["tf_error"], errors.InvalidArgumentError)

  def testUsingWrappedSessionShouldWorkAsContextManager(self):
    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)

    with wrapper as sess:
      self.assertAllClose([[3.0], [4.0]], self._s.eval())
      self.assertEqual(1, self._observer["on_run_start_count"])
      self.assertEqual(self._s, self._observer["run_fetches"])
      self.assertEqual(1, self._observer["on_run_end_count"])

      self.assertAllClose(
          [[11.0], [-1.0]],
          sess.run(self._q, feed_dict={self._ph: np.array([[1.0], [2.0]])}))
      self.assertEqual(2, self._observer["on_run_start_count"])
      self.assertEqual(self._q, self._observer["run_fetches"])
      self.assertEqual(2, self._observer["on_run_end_count"])

  def testUsingWrappedSessionShouldSupportEvalWithAsDefault(self):
    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)

    with wrapper.as_default():
      foo = constant_op.constant(42, name="foo")
      self.assertEqual(42, foo.eval())
      self.assertEqual(foo, self._observer["run_fetches"])

  def testWrapperShouldSupportSessionClose(self):
    wrapper = TestDebugWrapperSession(self._sess, self._dump_root,
                                      self._observer)
    wrapper.close()

  def testWrapperThreadNameFilterMainThread(self):
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer,
        thread_name_filter="MainThread")

    self._run_in_main_and_child_threads(wrapper)

    # Only the main thread's run (a_init) should have been dumped.
    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertEqual(1, dump.size)
    self.assertEqual("a_init", dump.dumped_tensor_data[0].node_name)

  def testWrapperThreadNameFilterChildThread(self):
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer,
        thread_name_filter=r"Child.*")

    self._run_in_main_and_child_threads(wrapper)

    # Only the child thread's run (b_init) should have been dumped.
    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertEqual(1, dump.size)
    self.assertEqual("b_init", dump.dumped_tensor_data[0].node_name)

  def testWrapperThreadNameFilterBothThreads(self):
    wrapper = TestDebugWrapperSession(
        self._sess, self._dump_root, self._observer,
        thread_name_filter=None)

    self._run_in_main_and_child_threads(wrapper)

    # With no filter, both runs are dumped.
    dump = debug_data.DebugDumpDir(self._dump_root, validate=False)
    self.assertEqual(2, dump.size)
    self.assertItemsEqual(
        ["a_init", "b_init"],
        [datum.node_name for datum in dump.dumped_tensor_data])


def _is_public_method_name(method_name):
  """Returns True for dunder names and for names without a leading "_"."""
  return ((method_name.startswith("__") and method_name.endswith("__")) or
          not method_name.startswith("_"))


def _public_method_names(cls):
  """Returns the names of the public methods defined on or inherited by `cls`.

  Plain functions are accepted alongside bound methods. On Python 3, a method
  looked up on a class is a plain function, so an `ismethod`-only predicate
  would find (almost) nothing and make the parity checks vacuous.
  """

  def _is_method_or_function(member):
    return tf_inspect.ismethod(member) or tf_inspect.isfunction(member)

  return [name for name, _ in
          tf_inspect.getmembers(cls, predicate=_is_method_or_function)
          if _is_public_method_name(name)]


class SessionWrapperPublicMethodParityTest(test_util.TensorFlowTestCase):

  def _assertWrapperHasAllPublicMethodsOf(self, cls):
    """Asserts BaseDebugWrapperSession has every public method of `cls`."""
    wrapper_public_methods = set(
        _public_method_names(framework.BaseDebugWrapperSession))
    missing_public_methods = [
        name for name in _public_method_names(cls)
        if name not in wrapper_public_methods]
    self.assertFalse(missing_public_methods)

  def testWrapperHasAllPublicMethodsOfSession(self):
    self._assertWrapperHasAllPublicMethodsOf(session.Session)

  def testWrapperHasAllPublicMethodsOfMonitoredSession(self):
    self._assertWrapperHasAllPublicMethodsOf(
        monitored_session.MonitoredSession)


if __name__ == "__main__":
  googletest.main()