1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Execution Callbacks for Eager Mode."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22
23import numpy as np
24
25from tensorflow.python import pywrap_tensorflow
26from tensorflow.python.eager import context
27from tensorflow.python.eager import core
28from tensorflow.python.eager import execute
29from tensorflow.python.platform import tf_logging as logging
30
# Action taken by the inf/nan callbacks in this module when the caller does not
# pass an explicit `action`.
_DEFAULT_CALLBACK_ACTION = "raise"
# Values accepted by `seterr(inf_or_nan=...)`. `None` is accepted there and
# means "leave the current action unchanged".
_VALID_CALLBACK_ACTIONS = (None, "ignore", "print", "raise", "warn")
33
34
35# TODO(cais): Consider moving this exception class to errors_impl.py.
class InfOrNanError(Exception):
  """Error signalling that an op output tensor holds `inf` and/or `nan`."""

  def __init__(self, op_type, op_name, output_index, num_outputs, value):
    """Creates the error and derives its message from the offending value.

    Args:
      op_type: Type name of the op that produced the tensor containing the
        `inf`(s) or `nan`(s) (e.g., `Div`).
      op_name: Client-assigned name of that op; `None` when no name was set.
      output_index: 0-based index, among the op's outputs, of the tensor that
        contains the `inf`(s) or `nan`(s).
      num_outputs: Total number of outputs of the op.
      value: The tensor value that contains the `inf`(s) or `nan`(s).
    """
    self._op_type = op_type
    self._op_name = op_name
    self._output_index = output_index
    self._num_outputs = num_outputs
    self._value = value

    # Element statistics are gathered once here; the message below uses them.
    self._total_count = np.size(value)
    self._inf_count = np.count_nonzero(np.isinf(value))
    self._nan_count = np.count_nonzero(np.isnan(value))

    super(InfOrNanError, self).__init__(self._get_error_message())

  def _get_error_message(self):
    """Returns the human-readable message describing this error."""
    # A set name is shown quoted; an unset name is shown as a bare `None`.
    if self._op_name is None:
      quoted_name = str(self._op_name)
    else:
      quoted_name = "'%s'" % self._op_name

    # The output index is reported 1-based ("Output 1 of 2").
    prefix = "Output %d of %d of TFE operation %s (name: %s) contains " % (
        self._output_index + 1, self._num_outputs, self._op_type, quoted_name)

    if self._inf_count and self._nan_count:
      counts = "%d inf(s) and %d nan(s) " % (self._inf_count, self._nan_count)
    elif self._inf_count:
      counts = "%d inf(s) " % self._inf_count
    else:
      counts = "%d nan(s) " % self._nan_count

    suffix = "out of a total of %d element(s). Tensor value: %s" % (
        self._total_count, self._value)
    return prefix + counts + suffix

  @property
  def op_type(self):
    """Type name of the op that produced the offending tensor."""
    return self._op_type

  @property
  def op_name(self):
    """Client-assigned name of the op, or `None` if it was not set."""
    return self._op_name

  @property
  def output_index(self):
    """0-based index of the offending tensor among the op's outputs."""
    return self._output_index

  @property
  def num_outputs(self):
    """Total number of outputs of the op."""
    return self._num_outputs

  @property
  def value(self):
    """The tensor value that was passed to the constructor."""
    return self._value
104
105
def inf_nan_callback(op_type,
                     inputs,
                     attrs,
                     outputs,
                     op_name,
                     check_inf=True,
                     check_nan=True,
                     action=_DEFAULT_CALLBACK_ACTION):
  """An execution callback that checks for `inf`s and `nan`s in output tensors.

  This callback can be used with `tfe.add_execution_callback` to check for
  invalid numeric values. E.g.,
  ```python
  tfe.add_execution_callback(tfe.inf_nan_callback)
  ```

  Args:
    op_type: Name of the TFE operation type (e.g., `MatMul`).
    inputs: The `list` of input tensors to the operation, currently unused by
      this callback.
    attrs: Attributes of the TFE operation, as a tuple of alternating attribute
      names and attribute values. Currently unused by this callback.
    outputs: The `list` of output tensors from the operation, checked by this
      callback for `inf` and `nan` values.
    op_name: Name of the TFE operation. This name is set by client and can be
      `None` if it is unset.
    check_inf: (`bool`) Whether this callback should check for `inf` values in
      the output tensor values.
    check_nan: (`bool`) Whether this callback should check for `nan` values in
      the output tensor values.
    action: (`str`) Action to be taken by the callback when `inf` or `nan`
      values are detected. Possible values {"raise", "warn", "print"}
      `"raise"`: Raise a `InfOrNanError`.
      `"warn"`: Log a warning using `tf.logging.warn`.
      `"print"`: Print a message to `sys.stdout`.

  Raises:
    InfOrNanError: iff `inf` or `nan` values are seen in any of `outputs` and
      `action` is `"raise"`.
    ValueError: if `inf` or `nan` values are seen in any of `outputs` and
      `action` is not one of the supported values. `action` is only inspected
      once such a value has been detected.
  """
  del attrs, inputs  # Not used.

  ctx = context.get_default_context()

  for index, output in enumerate(outputs):
    if not output.dtype.is_numpy_compatible:
      continue

    # `np.complexfloating` is the abstract NumPy type for complex dtypes; the
    # `np.complex` alias is not available in recent NumPy releases.
    numpy_dtype = output.dtype.as_numpy_dtype
    if not (np.issubdtype(numpy_dtype, np.floating) or
            np.issubdtype(numpy_dtype, np.complexfloating) or
            np.issubdtype(numpy_dtype, np.integer)):
      continue

    try:
      # The "T" attr must describe the tensor actually being checked, which is
      # not necessarily of the same dtype as the op's first output.
      check_numerics_op_attrs = (
          "message", "Eager-mode inf/nan check",
          "T", output.dtype.as_datatype_enum)
      # TODO(cais): Consider moving this into execute.py.
      # pylint: disable=protected-access
      pywrap_tensorflow.TFE_Py_Execute(
          ctx._handle, output.device, "CheckNumerics", [output],
          check_numerics_op_attrs, 1)
      # pylint: enable=protected-access
    except core._NotOkStatusException:  # pylint: disable=protected-access
      # A non-OK status only tells us the op did not succeed on this tensor.
      # Inspect the value with NumPy to find out which condition (if any) is
      # present, honoring `check_inf` / `check_nan`.
      value = output.numpy()
      inf_detected = np.any(np.isinf(value)) and check_inf
      nan_detected = np.any(np.isnan(value)) and check_nan
      if not inf_detected and not nan_detected:
        continue

      error = InfOrNanError(op_type, op_name, index, len(outputs), value)
      if action == "print":
        print("Warning: %s" % str(error))
      elif action == "warn":
        logging.warn(str(error))
      elif action == "raise":
        raise error
      else:
        raise ValueError(
            "Invalid action for inf_nan_callback: %s. Valid actions are: "
            "{print | warn | raise}" % action)
187
188
def inf_callback(op_type,
                 inputs,
                 attrs,
                 outputs,
                 op_name,
                 action=_DEFAULT_CALLBACK_ACTION):
  """Execution callback that checks output tensors for `inf`s only.

  Delegates to `inf_nan_callback` with the `nan` check switched off; all other
  arguments are forwarded unchanged.
  """
  check_flags = {"check_inf": True, "check_nan": False}
  inf_nan_callback(
      op_type, inputs, attrs, outputs, op_name, action=action, **check_flags)
205
206
def nan_callback(op_type,
                 inputs,
                 attrs,
                 outputs,
                 op_name,
                 action=_DEFAULT_CALLBACK_ACTION):
  """Execution callback that checks output tensors for `nan`s only.

  Delegates to `inf_nan_callback` with the `inf` check switched off; all other
  arguments are forwarded unchanged.
  """
  check_flags = {"check_inf": False, "check_nan": True}
  inf_nan_callback(
      op_type, inputs, attrs, outputs, op_name, action=action, **check_flags)
223
224
def add_execution_callback(callback):
  """Add an execution callback to the default eager context.

  An execution callback is invoked immediately after an eager operation or
  function has finished execution, providing access to the op's type, name,
  input and output tensors. Multiple execution callbacks can be added, in
  which case the callbacks will be invoked in the order in which they are
  added. To clear all execution callbacks that have been added, use
  `clear_execution_callbacks()`.

  Example:
  ```python
  def print_even_callback(op_type, inputs, attrs, outputs, op_name):
    # A callback that prints only the even output values.
    if outputs[0].numpy() % 2 == 0:
      print("Even output from %s: %s" % (op_name or op_type,  outputs))
  tfe.add_execution_callback(print_even_callback)

  x = tf.pow(2.0, 3.0) - 3.0
  y = tf.multiply(x, tf.add(1.0, 5.0))
  # When the line above is run, you will see all intermediate outputs that are
  # even numbers printed to the console.

  tfe.clear_execution_callbacks()
  ```

  Args:
    callback: a callable of the signature
      `f(op_type, inputs, attrs, outputs, op_name)`, i.e. the same parameter
      order as `inf_nan_callback` in this module.
      `op_type` is the type of the operation that was just executed (e.g.,
        `MatMul`).
      `inputs` is the `list` of input `Tensor`(s) to the op.
      `attrs` contains the attributes of the operation as a `tuple` of
        alternating attribute name and attribute value.
      `outputs` is the `list` of output `Tensor`(s) from the op.
      `op_name` is the name of the operation that was just executed. This
        name is set by the client who created the operation and can be `None` if
        it is unset.
       Return value(s) from the callback are ignored.
  """
  # Rebind the module-level `execute.execute` to the callback-aware variant.
  # NOTE(review): nothing in this module restores the original binding.
  execute.execute = execute.execute_with_callbacks
  context.get_default_context().add_post_execution_callback(callback)
267
268
def clear_execution_callbacks():
  """Remove every post-execution callback from the default eager context."""
  default_context = context.get_default_context()
  default_context.clear_post_execution_callbacks()
272
273
def seterr(inf_or_nan=None):
  """Set how abnormal conditions are handled by the default eager context.

  Example:
  ```python
  tfe.seterr(inf_or_nan="raise")
  a = tf.constant(10.0)
  b = tf.constant(0.0)
  try:
    c = a / b  # <-- Raises InfOrNanError.
  except Exception as e:
    print("Caught Exception: %s" % e)

  tfe.seterr(inf_or_nan="ignore")
  c = a / b  # <-- Does NOT raise exception anymore.
  ```

  Args:
    inf_or_nan: Set action for infinity (`inf`) and NaN (`nan`) values.
      Possible values: `{"ignore", "print", "raise", "warn"}`.
      `"ignore"`: take no action when `inf` or `nan` values appear.
      `"print"`: print a warning to `stdout`.
      `"raise"`: raise an `InfOrNanError`.
      `"warn"`: print a warning using `tf.logging.warn`.
      A value of `None` leads to no change in the action of the condition.

  Returns:
    A dictionary of old actions, keyed by condition name (`"inf_or_nan"`). The
    action is `"ignore"` when no `inf_nan_callback` was registered.

  Raises:
    ValueError: If the value of any keyword arguments is invalid.
  """
  if inf_or_nan not in _VALID_CALLBACK_ACTIONS:
    raise ValueError(
        "Invalid action value for inf_or_nan: %s. "
        "Valid actions are %s." % (inf_or_nan, _VALID_CALLBACK_ACTIONS))

  old_settings = {"inf_or_nan": "ignore"}
  default_context = context.get_default_context()

  # Callbacks other than `inf_nan_callback` must survive the reset below.
  carryover_callbacks = []
  for callback in default_context.post_execution_callbacks:
    if callback is inf_nan_callback:
      # Registered directly, so it runs with the default action.
      old_settings["inf_or_nan"] = _DEFAULT_CALLBACK_ACTION
    elif (isinstance(callback, functools.partial) and
          callback.func is inf_nan_callback):
      # `partial.keywords` can be `None` rather than an empty dict when no
      # keyword arguments were bound (e.g. on Python 2.7).
      bound_keywords = callback.keywords or {}
      old_settings["inf_or_nan"] = bound_keywords.get(
          "action", _DEFAULT_CALLBACK_ACTION)
    elif inf_or_nan is not None:
      carryover_callbacks.append(callback)

  if inf_or_nan is not None:
    # Rebuild the callback list: unrelated callbacks first (original order),
    # then the freshly configured inf/nan check, unless it is being disabled.
    default_context.clear_post_execution_callbacks()
    for callback in carryover_callbacks:
      default_context.add_post_execution_callback(callback)
    if inf_or_nan != "ignore":
      default_context.add_post_execution_callback(
          functools.partial(inf_nan_callback, action=inf_or_nan))

  return old_settings
338