1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""gRPC debug server in Python.""" 16# pylint: disable=g-bad-import-order 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import json 23import threading 24import time 25 26from concurrent import futures 27import grpc 28from six.moves import queue 29 30from tensorflow.core.debug import debug_service_pb2 31from tensorflow.core.framework import graph_pb2 32from tensorflow.python.debug.lib import debug_graphs 33from tensorflow.python.debug.lib import debug_service_pb2_grpc 34from tensorflow.python.platform import tf_logging as logging 35from tensorflow.python.util import compat 36 37DebugWatch = collections.namedtuple("DebugWatch", 38 ["node_name", "output_slot", "debug_op"]) 39 40 41def _state_change(new_state, node_name, output_slot, debug_op): 42 state_change = debug_service_pb2.EventReply.DebugOpStateChange() 43 state_change.state = new_state 44 state_change.node_name = node_name 45 state_change.output_slot = output_slot 46 state_change.debug_op = debug_op 47 return state_change 48 49 50class EventListenerBaseStreamHandler(object): 51 """Per-stream handler of EventListener gRPC streams.""" 52 53 def __init__(self): 54 """Constructor of EventListenerBaseStreamHandler.""" 55 56 def on_core_metadata_event(self, event): 57 """Callback for core metadata. 58 59 Args: 60 event: The Event proto that carries a JSON string in its 61 `log_message.message` field. 62 63 Returns: 64 `None` or an `EventReply` proto to be sent back to the client. If `None`, 65 an `EventReply` proto construct with the default no-arg constructor will 66 be sent back to the client. 67 """ 68 raise NotImplementedError( 69 "on_core_metadata_event() is not implemented in the base servicer " 70 "class") 71 72 def on_graph_def(self, graph_def, device_name, wall_time): 73 """Callback for Event proto received through the gRPC stream. 74 75 This Event proto carries a GraphDef, encoded as bytes, in its graph_def 76 field. 77 78 Args: 79 graph_def: A GraphDef object. 80 device_name: Name of the device on which the graph was created. 81 wall_time: An epoch timestamp (in microseconds) for the graph. 82 83 Returns: 84 `None` or an `EventReply` proto to be sent back to the client. If `None`, 85 an `EventReply` proto construct with the default no-arg constructor will 86 be sent back to the client. 87 """ 88 raise NotImplementedError( 89 "on_graph_def() is not implemented in the base servicer class") 90 91 def on_value_event(self, event): 92 """Callback for Event proto received through the gRPC stream. 93 94 This Event proto carries a Tensor in its summary.value[0] field. 95 96 Args: 97 event: The Event proto from the stream to be processed. 98 """ 99 raise NotImplementedError( 100 "on_value_event() is not implemented in the base servicer class") 101 102 103class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer): 104 """Base Python class for gRPC debug server.""" 105 106 def __init__(self, server_port, stream_handler_class): 107 """Constructor. 108 109 Args: 110 server_port: (int) Port number to bind to. 111 stream_handler_class: A class of the base class 112 `EventListenerBaseStreamHandler` that will be used to constructor 113 stream handler objects during `SendEvents` calls. 114 """ 115 116 self._server_port = server_port 117 self._stream_handler_class = stream_handler_class 118 119 self._server_lock = threading.Lock() 120 self._server_started = False 121 self._stop_requested = False 122 123 self._debug_ops_state_change_queue = queue.Queue() 124 self._gated_grpc_debug_watches = set() 125 self._breakpoints = set() 126 127 def SendEvents(self, request_iterator, context): 128 """Implementation of the SendEvents service method. 129 130 This method receives streams of Event protos from the client, and processes 131 them in ways specified in the on_event() callback. The stream is 132 bi-directional, but currently only the client-to-server stream (i.e., the 133 stream from the debug ops to the server) is used. 134 135 Args: 136 request_iterator: The incoming stream of Event protos. 137 context: Server context. 138 139 Raises: 140 ValueError: If there are more than one core metadata events. 141 142 Yields: 143 An empty stream of responses. 144 """ 145 core_metadata_count = 0 146 147 # A map from GraphDef hash to a list of received chunks. 148 graph_def_chunks = {} 149 tensor_chunks = {} 150 151 stream_handler = None 152 for event in request_iterator: 153 if not stream_handler: 154 stream_handler = self._stream_handler_class() 155 156 if event.summary and event.summary.value: 157 # An Event proto carrying a tensor value. 158 maybe_tensor_event = self._process_tensor_event_in_chunks( 159 event, tensor_chunks) 160 if maybe_tensor_event: 161 event_reply = stream_handler.on_value_event(maybe_tensor_event) 162 if event_reply is not None: 163 yield self._process_debug_op_state_changes(event_reply) 164 else: 165 # Non-tensor-value Event. 166 if event.graph_def: 167 # GraphDef-carrying Event. 168 maybe_graph_def, maybe_device_name, maybe_wall_time = ( 169 self._process_encoded_graph_def_in_chunks( 170 event, graph_def_chunks)) 171 if maybe_graph_def: 172 reply = stream_handler.on_graph_def( 173 maybe_graph_def, maybe_device_name, maybe_wall_time) 174 yield self._process_debug_op_state_changes(reply) 175 elif event.log_message.message: 176 # Core metadata-carrying Event. 177 core_metadata_count += 1 178 if core_metadata_count > 1: 179 raise ValueError( 180 "Expected one core metadata event; received multiple") 181 reply = stream_handler.on_core_metadata_event(event) 182 yield self._process_debug_op_state_changes(reply) 183 184 def _process_debug_op_state_changes(self, event_reply=None): 185 """Dequeue and process all the queued debug-op state change protos. 186 187 Include all the debug-op state change protos in a `EventReply` proto. 188 189 Args: 190 event_reply: An `EventReply` to add the `DebugOpStateChange` protos to, 191 or `None`. 192 193 Returns: 194 An `EventReply` proto with the dequeued `DebugOpStateChange` protos (if 195 any) added. 196 """ 197 if event_reply is None: 198 event_reply = debug_service_pb2.EventReply() 199 while not self._debug_ops_state_change_queue.empty(): 200 state_change = self._debug_ops_state_change_queue.get() 201 debug_node_key = (state_change.node_name, state_change.output_slot, 202 state_change.debug_op) 203 if (state_change.state == 204 debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE): 205 logging.info("Adding breakpoint %s:%d:%s", state_change.node_name, 206 state_change.output_slot, state_change.debug_op) 207 self._breakpoints.add(debug_node_key) 208 elif (state_change.state == 209 debug_service_pb2.EventReply.DebugOpStateChange.READ_ONLY): 210 logging.info("Adding watchpoint %s:%d:%s", state_change.node_name, 211 state_change.output_slot, state_change.debug_op) 212 if debug_node_key in self._breakpoints: 213 self._breakpoints.discard(debug_node_key) 214 elif (state_change.state == 215 debug_service_pb2.EventReply.DebugOpStateChange.DISABLED): 216 logging.info("Removing watchpoint or breakpoint: %s:%d:%s", 217 state_change.node_name, state_change.output_slot, 218 state_change.debug_op) 219 if debug_node_key in self._breakpoints: 220 self._breakpoints.discard(debug_node_key) 221 else: 222 logging.warn( 223 "Attempting to remove a non-existent debug node key: %s", 224 debug_node_key) 225 new_state_change = event_reply.debug_op_state_changes.add() 226 new_state_change.CopyFrom(state_change) 227 return event_reply 228 229 def _process_tensor_event_in_chunks(self, event, tensor_chunks): 230 """Possibly reassemble event chunks. 231 232 Due to gRPC's message size limit, a large tensor can be encapsulated in 233 multiple Event proto chunks to be sent through the debugger stream. This 234 method keeps track of the chunks that have arrived, reassemble all chunks 235 corresponding to a tensor when they have arrived and return the reassembled 236 Event proto. 237 238 Args: 239 event: The single Event proto that has arrived. 240 tensor_chunks: A dict used to keep track of the Event protos that have 241 arrived but haven't been reassembled. 242 243 Returns: 244 If all Event protos corresponding to a tensor have arrived, returns the 245 reassembled Event proto. Otherwise, return None. 246 """ 247 248 value = event.summary.value[0] 249 debugger_plugin_metadata = json.loads( 250 compat.as_text(value.metadata.plugin_data.content)) 251 device_name = debugger_plugin_metadata["device"] 252 num_chunks = debugger_plugin_metadata["numChunks"] 253 chunk_index = debugger_plugin_metadata["chunkIndex"] 254 255 if num_chunks <= 1: 256 return event 257 258 debug_node_name = value.node_name 259 timestamp = int(event.wall_time) 260 tensor_key = "%s_%s_%d" % (device_name, debug_node_name, timestamp) 261 262 if tensor_key not in tensor_chunks: 263 tensor_chunks[tensor_key] = [None] * num_chunks 264 265 chunks = tensor_chunks[tensor_key] 266 if value.tensor.tensor_content: 267 chunks[chunk_index] = value.tensor 268 elif value.tensor.string_val: 269 chunks[chunk_index] = event 270 271 if None not in chunks: 272 if value.tensor.tensor_content: 273 event.summary.value[0].tensor.tensor_content = b"".join( 274 chunk.tensor_content for chunk in chunks) 275 del tensor_chunks[tensor_key] 276 return event 277 elif value.tensor.string_val: 278 merged_event = chunks[0] 279 for chunk in chunks[1:]: 280 merged_event.summary.value[0].tensor.string_val.extend( 281 list(chunk.summary.value[0].tensor.string_val)) 282 return merged_event 283 284 def _process_encoded_graph_def_in_chunks(self, 285 event, 286 graph_def_chunks): 287 """Process an Event proto containing a chunk of encoded GraphDef. 288 289 Args: 290 event: the Event proto containing the chunk of encoded GraphDef. 291 graph_def_chunks: A dict mapping keys for GraphDefs (i.e., 292 "<graph_def_hash>,<device_name>,<wall_time>") to a list of chunks of 293 encoded GraphDefs. 294 295 Returns: 296 If all chunks of the GraphDef have arrived, 297 return decoded GraphDef proto, device name, wall_time. 298 Otherwise, 299 return None, None, None. 300 """ 301 graph_def = graph_pb2.GraphDef() 302 index_bar_0 = event.graph_def.find(b"|") 303 index_bar_1 = event.graph_def.find(b"|", index_bar_0 + 1) 304 index_bar_2 = event.graph_def.find(b"|", index_bar_1 + 1) 305 graph_def_hash_device_timestamp = event.graph_def[:index_bar_0] 306 chunk_index = int(event.graph_def[index_bar_0 + 1 : index_bar_1]) 307 num_chunks = int(event.graph_def[index_bar_1 + 1 : index_bar_2]) 308 if graph_def_hash_device_timestamp not in graph_def_chunks: 309 graph_def_chunks[graph_def_hash_device_timestamp] = [None] * num_chunks 310 graph_def_chunks[graph_def_hash_device_timestamp][ 311 chunk_index] = event.graph_def[index_bar_2 + 1:] 312 if all(graph_def_chunks[graph_def_hash_device_timestamp]): 313 device_name = graph_def_hash_device_timestamp.split(b",")[1] 314 wall_time = int(graph_def_hash_device_timestamp.split(b",")[2]) 315 graph_def.ParseFromString( 316 b"".join(graph_def_chunks[graph_def_hash_device_timestamp])) 317 del graph_def_chunks[graph_def_hash_device_timestamp] 318 self._process_graph_def(graph_def) 319 return graph_def, device_name, wall_time 320 else: 321 return None, None, None 322 323 def _process_graph_def(self, graph_def): 324 for node_def in graph_def.node: 325 if (debug_graphs.is_debug_node(node_def.name) and 326 node_def.attr["gated_grpc"].b): 327 node_name, output_slot, _, debug_op = ( 328 debug_graphs.parse_debug_node_name(node_def.name)) 329 self._gated_grpc_debug_watches.add( 330 DebugWatch(node_name, output_slot, debug_op)) 331 332 def run_server(self, blocking=True): 333 """Start running the server. 334 335 Args: 336 blocking: If `True`, block until `stop_server()` is invoked. 337 338 Raises: 339 ValueError: If server stop has already been requested, or if the server 340 has already started running. 341 """ 342 self._server_lock.acquire() 343 try: 344 if self._stop_requested: 345 raise ValueError("Server has already stopped") 346 if self._server_started: 347 raise ValueError("Server has already started running") 348 349 self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) 350 debug_service_pb2_grpc.add_EventListenerServicer_to_server(self, 351 self.server) 352 self.server.add_insecure_port("[::]:%d" % self._server_port) 353 self.server.start() 354 self._server_started = True 355 finally: 356 self._server_lock.release() 357 358 if blocking: 359 while not self._stop_requested: 360 time.sleep(1.0) 361 362 def stop_server(self, grace=1.0): 363 """Request server stopping. 364 365 Once stopped, server cannot be stopped or started again. This method is 366 non-blocking. Call `wait()` on the returned event to block until the server 367 has completely stopped. 368 369 Args: 370 grace: Grace period in seconds to be used when calling `server.stop()`. 371 372 Raises: 373 ValueError: If server stop has already been requested, or if the server 374 has not started running yet. 375 376 Returns: 377 A threading.Event that will be set when the server has completely stopped. 378 """ 379 self._server_lock.acquire() 380 try: 381 if not self._server_started: 382 raise ValueError("Server has not started running") 383 if self._stop_requested: 384 raise ValueError("Server has already stopped") 385 386 self._stop_requested = True 387 return self.server.stop(grace=grace) 388 finally: 389 self._server_lock.release() 390 391 def request_watch(self, node_name, output_slot, debug_op, breakpoint=False): 392 """Request enabling a debug tensor watchpoint or breakpoint. 393 394 This will let the server send a EventReply to the client side 395 (i.e., the debugged TensorFlow runtime process) to request adding a watch 396 key (i.e., <node_name>:<output_slot>:<debug_op>) to the list of enabled 397 watch keys. The list applies only to debug ops with the attribute 398 gated_grpc=True. 399 400 To disable the watch, use `request_unwatch()`. 401 402 Args: 403 node_name: (`str`) name of the node that the to-be-watched tensor belongs 404 to, e.g., "hidden/Weights". 405 output_slot: (`int`) output slot index of the tensor to watch. 406 debug_op: (`str`) name of the debug op to enable. This should not include 407 any attribute substrings. 408 breakpoint: (`bool`) Iff `True`, the debug op will block and wait until it 409 receives an `EventReply` response from the server. The `EventReply` 410 proto may carry a TensorProto that modifies the value of the debug op's 411 output tensor. 412 """ 413 self._debug_ops_state_change_queue.put( 414 _state_change( 415 debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE 416 if breakpoint 417 else debug_service_pb2.EventReply.DebugOpStateChange.READ_ONLY, 418 node_name, output_slot, debug_op)) 419 420 def request_unwatch(self, node_name, output_slot, debug_op): 421 """Request disabling a debug tensor watchpoint or breakpoint. 422 423 This is the opposite of `request_watch()`. 424 425 Args: 426 node_name: (`str`) name of the node that the to-be-watched tensor belongs 427 to, e.g., "hidden/Weights". 428 output_slot: (`int`) output slot index of the tensor to watch. 429 debug_op: (`str`) name of the debug op to enable. This should not include 430 any attribute substrings. 431 """ 432 self._debug_ops_state_change_queue.put( 433 _state_change( 434 debug_service_pb2.EventReply.DebugOpStateChange.DISABLED, node_name, 435 output_slot, debug_op)) 436 437 @property 438 def breakpoints(self): 439 """Get a set of the currently-activated breakpoints. 440 441 Returns: 442 A `set` of 3-tuples: (node_name, output_slot, debug_op), e.g., 443 {("MatMul", 0, "DebugIdentity")}. 444 """ 445 return self._breakpoints 446 447 def gated_grpc_debug_watches(self): 448 """Get the list of debug watches with attribute gated_grpc=True. 449 450 Since the server receives `GraphDef` from the debugged runtime, it can only 451 return such debug watches that it has received so far. 452 453 Returns: 454 A `list` of `DebugWatch` `namedtuples` representing the debug watches with 455 gated_grpc=True. Each `namedtuple` element has the attributes: 456 `node_name` as a `str`, 457 `output_slot` as an `int`, 458 `debug_op` as a `str`. 459 """ 460 return list(self._gated_grpc_debug_watches) 461 462 def SendTracebacks(self, request, context): 463 """Base implementation of the handling of SendTracebacks calls. 464 465 The base implementation does nothing with the incoming request. 466 Override in an implementation of the server if necessary. 467 468 Args: 469 request: A `CallTraceback` proto, containing information about the 470 type (e.g., graph vs. eager execution) and source-code traceback of the 471 call and (any) associated `tf.Graph`s. 472 context: Server context. 473 474 Returns: 475 A `EventReply` proto. 476 """ 477 return debug_service_pb2.EventReply() 478 479 def SendSourceFiles(self, request, context): 480 """Base implementation of the handling of SendSourceFiles calls. 481 482 The base implementation does nothing with the incoming request. 483 Override in an implementation of the server if necessary. 484 485 Args: 486 request: A `DebuggedSourceFiles` proto, containing the path, content, size 487 and last-modified timestamp of source files. 488 context: Server context. 489 490 Returns: 491 A `EventReply` proto. 492 """ 493 return debug_service_pb2.EventReply() 494