1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""gRPC debug server in Python."""
16# pylint: disable=g-bad-import-order
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import json
23import threading
24import time
25
26from concurrent import futures
27import grpc
28from six.moves import queue
29
30from tensorflow.core.debug import debug_service_pb2
31from tensorflow.core.framework import graph_pb2
32from tensorflow.python.debug.lib import debug_graphs
33from tensorflow.python.debug.lib import debug_service_pb2_grpc
34from tensorflow.python.platform import tf_logging as logging
35from tensorflow.python.util import compat
36
37DebugWatch = collections.namedtuple("DebugWatch",
38                                    ["node_name", "output_slot", "debug_op"])
39
40
41def _state_change(new_state, node_name, output_slot, debug_op):
42  state_change = debug_service_pb2.EventReply.DebugOpStateChange()
43  state_change.state = new_state
44  state_change.node_name = node_name
45  state_change.output_slot = output_slot
46  state_change.debug_op = debug_op
47  return state_change
48
49
50class EventListenerBaseStreamHandler(object):
51  """Per-stream handler of EventListener gRPC streams."""
52
53  def __init__(self):
54    """Constructor of EventListenerBaseStreamHandler."""
55
56  def on_core_metadata_event(self, event):
57    """Callback for core metadata.
58
59    Args:
60      event: The Event proto that carries a JSON string in its
61        `log_message.message` field.
62
63    Returns:
64      `None` or an `EventReply` proto to be sent back to the client. If `None`,
65      an `EventReply` proto construct with the default no-arg constructor will
66      be sent back to the client.
67    """
68    raise NotImplementedError(
69        "on_core_metadata_event() is not implemented in the base servicer "
70        "class")
71
72  def on_graph_def(self, graph_def, device_name, wall_time):
73    """Callback for Event proto received through the gRPC stream.
74
75    This Event proto carries a GraphDef, encoded as bytes, in its graph_def
76    field.
77
78    Args:
79      graph_def: A GraphDef object.
80      device_name: Name of the device on which the graph was created.
81      wall_time: An epoch timestamp (in microseconds) for the graph.
82
83    Returns:
84      `None` or an `EventReply` proto to be sent back to the client. If `None`,
85      an `EventReply` proto construct with the default no-arg constructor will
86      be sent back to the client.
87    """
88    raise NotImplementedError(
89        "on_graph_def() is not implemented in the base servicer class")
90
91  def on_value_event(self, event):
92    """Callback for Event proto received through the gRPC stream.
93
94    This Event proto carries a Tensor in its summary.value[0] field.
95
96    Args:
97      event: The Event proto from the stream to be processed.
98    """
99    raise NotImplementedError(
100        "on_value_event() is not implemented in the base servicer class")
101
102
103class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
104  """Base Python class for gRPC debug server."""
105
106  def __init__(self, server_port, stream_handler_class):
107    """Constructor.
108
109    Args:
110      server_port: (int) Port number to bind to.
111      stream_handler_class: A class of the base class
112        `EventListenerBaseStreamHandler` that will be used to constructor
113        stream handler objects during `SendEvents` calls.
114    """
115
116    self._server_port = server_port
117    self._stream_handler_class = stream_handler_class
118
119    self._server_lock = threading.Lock()
120    self._server_started = False
121    self._stop_requested = False
122
123    self._debug_ops_state_change_queue = queue.Queue()
124    self._gated_grpc_debug_watches = set()
125    self._breakpoints = set()
126
127  def SendEvents(self, request_iterator, context):
128    """Implementation of the SendEvents service method.
129
130    This method receives streams of Event protos from the client, and processes
131    them in ways specified in the on_event() callback. The stream is
132    bi-directional, but currently only the client-to-server stream (i.e., the
133    stream from the debug ops to the server) is used.
134
135    Args:
136      request_iterator: The incoming stream of Event protos.
137      context: Server context.
138
139    Raises:
140      ValueError: If there are more than one core metadata events.
141
142    Yields:
143      An empty stream of responses.
144    """
145    core_metadata_count = 0
146
147    # A map from GraphDef hash to a list of received chunks.
148    graph_def_chunks = {}
149    tensor_chunks = {}
150
151    stream_handler = None
152    for event in request_iterator:
153      if not stream_handler:
154        stream_handler = self._stream_handler_class()
155
156      if event.summary and event.summary.value:
157        # An Event proto carrying a tensor value.
158        maybe_tensor_event = self._process_tensor_event_in_chunks(
159            event, tensor_chunks)
160        if maybe_tensor_event:
161          event_reply = stream_handler.on_value_event(maybe_tensor_event)
162          if event_reply is not None:
163            yield self._process_debug_op_state_changes(event_reply)
164      else:
165        # Non-tensor-value Event.
166        if event.graph_def:
167          # GraphDef-carrying Event.
168          maybe_graph_def, maybe_device_name, maybe_wall_time = (
169              self._process_encoded_graph_def_in_chunks(
170                  event, graph_def_chunks))
171          if maybe_graph_def:
172            reply = stream_handler.on_graph_def(
173                maybe_graph_def, maybe_device_name, maybe_wall_time)
174            yield self._process_debug_op_state_changes(reply)
175        elif event.log_message.message:
176          # Core metadata-carrying Event.
177          core_metadata_count += 1
178          if core_metadata_count > 1:
179            raise ValueError(
180                "Expected one core metadata event; received multiple")
181          reply = stream_handler.on_core_metadata_event(event)
182          yield self._process_debug_op_state_changes(reply)
183
184  def _process_debug_op_state_changes(self, event_reply=None):
185    """Dequeue and process all the queued debug-op state change protos.
186
187    Include all the debug-op state change protos in a `EventReply` proto.
188
189    Args:
190      event_reply: An `EventReply` to add the `DebugOpStateChange` protos to,
191        or `None`.
192
193    Returns:
194      An `EventReply` proto with the dequeued `DebugOpStateChange` protos (if
195        any) added.
196    """
197    if event_reply is None:
198      event_reply = debug_service_pb2.EventReply()
199    while not self._debug_ops_state_change_queue.empty():
200      state_change = self._debug_ops_state_change_queue.get()
201      debug_node_key = (state_change.node_name, state_change.output_slot,
202                        state_change.debug_op)
203      if (state_change.state ==
204          debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE):
205        logging.info("Adding breakpoint %s:%d:%s", state_change.node_name,
206                     state_change.output_slot, state_change.debug_op)
207        self._breakpoints.add(debug_node_key)
208      elif (state_change.state ==
209            debug_service_pb2.EventReply.DebugOpStateChange.READ_ONLY):
210        logging.info("Adding watchpoint %s:%d:%s", state_change.node_name,
211                     state_change.output_slot, state_change.debug_op)
212        if debug_node_key in self._breakpoints:
213          self._breakpoints.discard(debug_node_key)
214      elif (state_change.state ==
215            debug_service_pb2.EventReply.DebugOpStateChange.DISABLED):
216        logging.info("Removing watchpoint or breakpoint: %s:%d:%s",
217                     state_change.node_name, state_change.output_slot,
218                     state_change.debug_op)
219        if debug_node_key in self._breakpoints:
220          self._breakpoints.discard(debug_node_key)
221        else:
222          logging.warn(
223              "Attempting to remove a non-existent debug node key: %s",
224              debug_node_key)
225      new_state_change = event_reply.debug_op_state_changes.add()
226      new_state_change.CopyFrom(state_change)
227    return event_reply
228
229  def _process_tensor_event_in_chunks(self, event, tensor_chunks):
230    """Possibly reassemble event chunks.
231
232    Due to gRPC's message size limit, a large tensor can be encapsulated in
233    multiple Event proto chunks to be sent through the debugger stream. This
234    method keeps track of the chunks that have arrived, reassemble all chunks
235    corresponding to a tensor when they have arrived and return the reassembled
236    Event proto.
237
238    Args:
239      event: The single Event proto that has arrived.
240      tensor_chunks: A dict used to keep track of the Event protos that have
241        arrived but haven't been reassembled.
242
243    Returns:
244      If all Event protos corresponding to a tensor have arrived, returns the
245      reassembled Event proto. Otherwise, return None.
246    """
247
248    value = event.summary.value[0]
249    debugger_plugin_metadata = json.loads(
250        compat.as_text(value.metadata.plugin_data.content))
251    device_name = debugger_plugin_metadata["device"]
252    num_chunks = debugger_plugin_metadata["numChunks"]
253    chunk_index = debugger_plugin_metadata["chunkIndex"]
254
255    if num_chunks <= 1:
256      return event
257
258    debug_node_name = value.node_name
259    timestamp = int(event.wall_time)
260    tensor_key = "%s_%s_%d" % (device_name, debug_node_name, timestamp)
261
262    if tensor_key not in tensor_chunks:
263      tensor_chunks[tensor_key] = [None] * num_chunks
264
265    chunks = tensor_chunks[tensor_key]
266    if value.tensor.tensor_content:
267      chunks[chunk_index] = value.tensor
268    elif value.tensor.string_val:
269      chunks[chunk_index] = event
270
271    if None not in chunks:
272      if value.tensor.tensor_content:
273        event.summary.value[0].tensor.tensor_content = b"".join(
274            chunk.tensor_content for chunk in chunks)
275        del tensor_chunks[tensor_key]
276        return event
277      elif value.tensor.string_val:
278        merged_event = chunks[0]
279        for chunk in chunks[1:]:
280          merged_event.summary.value[0].tensor.string_val.extend(
281              list(chunk.summary.value[0].tensor.string_val))
282        return merged_event
283
284  def _process_encoded_graph_def_in_chunks(self,
285                                           event,
286                                           graph_def_chunks):
287    """Process an Event proto containing a chunk of encoded GraphDef.
288
289    Args:
290      event: the Event proto containing the chunk of encoded GraphDef.
291      graph_def_chunks: A dict mapping keys for GraphDefs (i.e.,
292      "<graph_def_hash>,<device_name>,<wall_time>") to a list of chunks of
293      encoded GraphDefs.
294
295    Returns:
296      If all chunks of the GraphDef have arrived,
297        return decoded GraphDef proto, device name, wall_time.
298      Otherwise,
299        return None, None, None.
300    """
301    graph_def = graph_pb2.GraphDef()
302    index_bar_0 = event.graph_def.find(b"|")
303    index_bar_1 = event.graph_def.find(b"|", index_bar_0 + 1)
304    index_bar_2 = event.graph_def.find(b"|", index_bar_1 + 1)
305    graph_def_hash_device_timestamp = event.graph_def[:index_bar_0]
306    chunk_index = int(event.graph_def[index_bar_0 + 1 : index_bar_1])
307    num_chunks = int(event.graph_def[index_bar_1 + 1 : index_bar_2])
308    if graph_def_hash_device_timestamp not in graph_def_chunks:
309      graph_def_chunks[graph_def_hash_device_timestamp] = [None] * num_chunks
310    graph_def_chunks[graph_def_hash_device_timestamp][
311        chunk_index] = event.graph_def[index_bar_2 + 1:]
312    if all(graph_def_chunks[graph_def_hash_device_timestamp]):
313      device_name = graph_def_hash_device_timestamp.split(b",")[1]
314      wall_time = int(graph_def_hash_device_timestamp.split(b",")[2])
315      graph_def.ParseFromString(
316          b"".join(graph_def_chunks[graph_def_hash_device_timestamp]))
317      del graph_def_chunks[graph_def_hash_device_timestamp]
318      self._process_graph_def(graph_def)
319      return graph_def, device_name, wall_time
320    else:
321      return None, None, None
322
323  def _process_graph_def(self, graph_def):
324    for node_def in graph_def.node:
325      if (debug_graphs.is_debug_node(node_def.name) and
326          node_def.attr["gated_grpc"].b):
327        node_name, output_slot, _, debug_op = (
328            debug_graphs.parse_debug_node_name(node_def.name))
329        self._gated_grpc_debug_watches.add(
330            DebugWatch(node_name, output_slot, debug_op))
331
332  def run_server(self, blocking=True):
333    """Start running the server.
334
335    Args:
336      blocking: If `True`, block until `stop_server()` is invoked.
337
338    Raises:
339      ValueError: If server stop has already been requested, or if the server
340        has already started running.
341    """
342    self._server_lock.acquire()
343    try:
344      if self._stop_requested:
345        raise ValueError("Server has already stopped")
346      if self._server_started:
347        raise ValueError("Server has already started running")
348
349      self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
350      debug_service_pb2_grpc.add_EventListenerServicer_to_server(self,
351                                                                 self.server)
352      self.server.add_insecure_port("[::]:%d" % self._server_port)
353      self.server.start()
354      self._server_started = True
355    finally:
356      self._server_lock.release()
357
358    if blocking:
359      while not self._stop_requested:
360        time.sleep(1.0)
361
362  def stop_server(self, grace=1.0):
363    """Request server stopping.
364
365    Once stopped, server cannot be stopped or started again. This method is
366    non-blocking. Call `wait()` on the returned event to block until the server
367    has completely stopped.
368
369    Args:
370      grace: Grace period in seconds to be used when calling `server.stop()`.
371
372    Raises:
373      ValueError: If server stop has already been requested, or if the server
374        has not started running yet.
375
376    Returns:
377      A threading.Event that will be set when the server has completely stopped.
378    """
379    self._server_lock.acquire()
380    try:
381      if not self._server_started:
382        raise ValueError("Server has not started running")
383      if self._stop_requested:
384        raise ValueError("Server has already stopped")
385
386      self._stop_requested = True
387      return self.server.stop(grace=grace)
388    finally:
389      self._server_lock.release()
390
391  def request_watch(self, node_name, output_slot, debug_op, breakpoint=False):
392    """Request enabling a debug tensor watchpoint or breakpoint.
393
394    This will let the server send a EventReply to the client side
395    (i.e., the debugged TensorFlow runtime process) to request adding a watch
396    key (i.e., <node_name>:<output_slot>:<debug_op>) to the list of enabled
397    watch keys. The list applies only to debug ops with the attribute
398    gated_grpc=True.
399
400    To disable the watch, use `request_unwatch()`.
401
402    Args:
403      node_name: (`str`) name of the node that the to-be-watched tensor belongs
404        to, e.g., "hidden/Weights".
405      output_slot: (`int`) output slot index of the tensor to watch.
406      debug_op: (`str`) name of the debug op to enable. This should not include
407        any attribute substrings.
408      breakpoint: (`bool`) Iff `True`, the debug op will block and wait until it
409        receives an `EventReply` response from the server. The `EventReply`
410        proto may carry a TensorProto that modifies the value of the debug op's
411        output tensor.
412    """
413    self._debug_ops_state_change_queue.put(
414        _state_change(
415            debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE
416            if breakpoint
417            else debug_service_pb2.EventReply.DebugOpStateChange.READ_ONLY,
418            node_name, output_slot, debug_op))
419
420  def request_unwatch(self, node_name, output_slot, debug_op):
421    """Request disabling a debug tensor watchpoint or breakpoint.
422
423    This is the opposite of `request_watch()`.
424
425    Args:
426      node_name: (`str`) name of the node that the to-be-watched tensor belongs
427        to, e.g., "hidden/Weights".
428      output_slot: (`int`) output slot index of the tensor to watch.
429      debug_op: (`str`) name of the debug op to enable. This should not include
430        any attribute substrings.
431    """
432    self._debug_ops_state_change_queue.put(
433        _state_change(
434            debug_service_pb2.EventReply.DebugOpStateChange.DISABLED, node_name,
435            output_slot, debug_op))
436
437  @property
438  def breakpoints(self):
439    """Get a set of the currently-activated breakpoints.
440
441    Returns:
442      A `set` of 3-tuples: (node_name, output_slot, debug_op), e.g.,
443        {("MatMul", 0, "DebugIdentity")}.
444    """
445    return self._breakpoints
446
447  def gated_grpc_debug_watches(self):
448    """Get the list of debug watches with attribute gated_grpc=True.
449
450    Since the server receives `GraphDef` from the debugged runtime, it can only
451    return such debug watches that it has received so far.
452
453    Returns:
454      A `list` of `DebugWatch` `namedtuples` representing the debug watches with
455      gated_grpc=True. Each `namedtuple` element has the attributes:
456        `node_name` as a `str`,
457        `output_slot` as an `int`,
458        `debug_op` as a `str`.
459    """
460    return list(self._gated_grpc_debug_watches)
461
462  def SendTracebacks(self, request, context):
463    """Base implementation of the handling of SendTracebacks calls.
464
465    The base implementation does nothing with the incoming request.
466    Override in an implementation of the server if necessary.
467
468    Args:
469      request: A `CallTraceback` proto, containing information about the
470        type (e.g., graph vs. eager execution) and source-code traceback of the
471        call and (any) associated `tf.Graph`s.
472      context: Server context.
473
474    Returns:
475      A `EventReply` proto.
476    """
477    return debug_service_pb2.EventReply()
478
479  def SendSourceFiles(self, request, context):
480    """Base implementation of the handling of SendSourceFiles calls.
481
482    The base implementation does nothing with the incoming request.
483    Override in an implementation of the server if necessary.
484
485    Args:
486      request: A `DebuggedSourceFiles` proto, containing the path, content, size
487        and last-modified timestamp of source files.
488      context: Server context.
489
490    Returns:
491      A `EventReply` proto.
492    """
493    return debug_service_pb2.EventReply()
494