1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Execution Callbacks for Eager Mode.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22 23import numpy as np 24 25from tensorflow.python import pywrap_tensorflow 26from tensorflow.python.eager import context 27from tensorflow.python.eager import core 28from tensorflow.python.eager import execute 29from tensorflow.python.platform import tf_logging as logging 30 31_DEFAULT_CALLBACK_ACTION = "raise" 32_VALID_CALLBACK_ACTIONS = (None, "ignore", "print", "raise", "warn") 33 34 35# TODO(cais): Consider moving this exception class to errors_impl.py. 36class InfOrNanError(Exception): 37 """Exception for inf and/or nan being present in tensor.""" 38 39 def __init__(self, 40 op_type, 41 op_name, 42 output_index, 43 num_outputs, 44 value): 45 """Constructor of InfOrNanError. 46 47 Args: 48 op_type: Type name of the op that generated the tensor that generated the 49 `inf`(s) or `nan`(s) (e.g., `Div`). 50 op_name: Name of the op that generated the tensor with `inf`(s) or 51 `nan`(s). This name is set by client and can be `None` if it is unset. 52 output_index: The 0-based output index of the tensor that contains 53 `inf`(s) or `nan`(s). 54 num_outputs: Total number of outputs of the operation. 55 value: The tensor value that contains `inf`(s) or `nan`(s). 56 """ 57 self._op_type = op_type 58 self._op_name = op_name 59 self._output_index = output_index 60 self._num_outputs = num_outputs 61 self._value = value 62 63 self._total_count = np.size(value) 64 self._inf_count = np.count_nonzero(np.isinf(value)) 65 self._nan_count = np.count_nonzero(np.isnan(value)) 66 67 super(InfOrNanError, self).__init__(self._get_error_message()) 68 69 def _get_error_message(self): 70 """Get the error message describing this InfOrNanError object.""" 71 name_str = (("'%s'" % self._op_name) if self._op_name is not None 72 else str(self._op_name)) 73 msg = "Output %d of %d of TFE operation %s (name: %s) contains " % ( 74 self._output_index + 1, self._num_outputs, self._op_type, name_str) 75 if self._inf_count and self._nan_count: 76 msg += "%d inf(s) and %d nan(s) " % (self._inf_count, self._nan_count) 77 elif self._inf_count: 78 msg += "%d inf(s) " % self._inf_count 79 else: 80 msg += "%d nan(s) " % self._nan_count 81 msg += "out of a total of %d element(s). Tensor value: %s" % ( 82 self._total_count, self._value) 83 return msg 84 85 @property 86 def op_type(self): 87 return self._op_type 88 89 @property 90 def op_name(self): 91 return self._op_name 92 93 @property 94 def output_index(self): 95 return self._output_index 96 97 @property 98 def num_outputs(self): 99 return self._num_outputs 100 101 @property 102 def value(self): 103 return self._value 104 105 106def inf_nan_callback(op_type, 107 inputs, 108 attrs, 109 outputs, 110 op_name, 111 check_inf=True, 112 check_nan=True, 113 action=_DEFAULT_CALLBACK_ACTION): 114 """An execution callback that checks for `inf`s and `nan`s in output tensors. 115 116 This callback can be used with `tfe.add_execute_callback` to check for invalid 117 numeric values. E.g., 118 ```python 119 tfe.add_execute_callback(tfe.inf_nan_callback) 120 ``` 121 122 Args: 123 op_type: Name of the TFE operation type (e.g., `MatMul`). 124 inputs: The `list` of input tensors to the operation, currently unused by 125 this callback. 126 attrs: Attributes of the TFE operation, as a tuple of alternating attribute 127 names and attribute values. 128 outputs: The `list` of output tensors from the operation, checked by this 129 callback for `inf` and `nan` values. 130 op_name: Name of the TFE operation. This name is set by client and can be 131 `None` if it unset. 132 check_inf: (`bool`) Whether this callback should check for `inf` values in 133 the output tensor values. 134 check_nan: (`bool`) Whether this callback should check for `nan` values in 135 the output tensor values. 136 action: (`str`) Action to be taken by the callback when `inf` or `nan` 137 values are detected. Possible values {"raise", "warn", "print"} 138 `"raise"`: Raise a `InfOrNanError`. 139 `"warn"`: Log a warning using `tf.logging.warn`. 140 `"print"`: Print a message to `sys.stdout`. 141 142 Raises: 143 InfOrNanError: iff `inf` or `nan` values are seen in any of `outputs` and 144 `action` is `"raise"`. 145 ValueError: iff the value of `action` is invalid. 146 """ 147 del attrs, inputs # Not used. 148 149 ctx = context.get_default_context() 150 151 for index, output in enumerate(outputs): 152 if not output.dtype.is_numpy_compatible: 153 continue 154 155 numpy_dtype = output.dtype.as_numpy_dtype 156 if (np.issubdtype(numpy_dtype, np.floating) or 157 np.issubdtype(numpy_dtype, np.complex) or 158 np.issubdtype(numpy_dtype, np.integer)): 159 try: 160 check_numerics_op_attrs = ( 161 "message", "Eager-mode inf/nan check", 162 "T", outputs[0].dtype.as_datatype_enum) 163 # TODO(cais): Consider moving this into execute.py. 164 # pylint: disable=protected-access 165 pywrap_tensorflow.TFE_Py_Execute( 166 ctx._handle, output.device, "CheckNumerics", [output], 167 check_numerics_op_attrs, 1) 168 # pylint: enable=protected-access 169 except core._NotOkStatusException: # pylint: disable=protected-access 170 value = output.numpy() 171 inf_detected = np.any(np.isinf(value)) and check_inf 172 nan_detected = np.any(np.isnan(value)) and check_nan 173 if not inf_detected and not nan_detected: 174 continue 175 176 error = InfOrNanError(op_type, op_name, index, len(outputs), value) 177 if action == "print": 178 print("Warning: %s" % str(error)) 179 elif action == "warn": 180 logging.warn(str(error)) 181 elif action == "raise": 182 raise error 183 else: 184 raise ValueError( 185 "Invalid action for inf_nan_callback: %s. Valid actions are: " 186 "{print | warn | raise}" % action) 187 188 189def inf_callback(op_type, 190 inputs, 191 attrs, 192 outputs, 193 op_name, 194 action=_DEFAULT_CALLBACK_ACTION): 195 """A specialization of `inf_nan_callback` that checks for `inf`s only.""" 196 inf_nan_callback( 197 op_type, 198 inputs, 199 attrs, 200 outputs, 201 op_name, 202 check_inf=True, 203 check_nan=False, 204 action=action) 205 206 207def nan_callback(op_type, 208 inputs, 209 attrs, 210 outputs, 211 op_name, 212 action=_DEFAULT_CALLBACK_ACTION): 213 """A specialization of `inf_nan_callback` that checks for `nan`s only.""" 214 inf_nan_callback( 215 op_type, 216 inputs, 217 attrs, 218 outputs, 219 op_name, 220 check_inf=False, 221 check_nan=True, 222 action=action) 223 224 225def add_execution_callback(callback): 226 """Add an execution callback to the default eager context. 227 228 An execution callback is invoked immediately after an eager operation or 229 function has finished execution, providing access to the op's type, name 230 input and output tensors. Multiple execution callbacks can be added, in 231 which case the callbacks will be invoked in the order in which they are 232 added. To clear all execution callbacks that have been added, use 233 `clear_execution_callbacks()`. 234 235 Example: 236 ```python 237 def print_even_callback(op_type, op_name, attrs, inputs, outputs): 238 # A callback that prints only the even output values. 239 if outputs[0].numpy() % 2 == 0: 240 print("Even output from %s: %s" % (op_name or op_type, outputs)) 241 tfe.add_execution_callback(print_even_callback) 242 243 x = tf.pow(2.0, 3.0) - 3.0 244 y = tf.multiply(x, tf.add(1.0, 5.0)) 245 # When the line above is run, you will see all intermediate outputs that are 246 # even numbers printed to the console. 247 248 tfe.clear_execution_callbacks() 249 ``` 250 251 Args: 252 callback: a callable of the signature 253 `f(op_type, op_name, attrs, inputs, outputs)`. 254 `op_type` is the type of the operation that was just executed (e.g., 255 `MatMul`). 256 `op_name` is the name of the operation that has was just executed. This 257 name is set by the client who created the operation and can be `None` if 258 it is unset. 259 `attrs` contains the attributes of the operation as a `tuple` of 260 alternating attribute name and attribute value. 261 `inputs` is the `list` of input `Tensor`(s) to the op. 262 `outputs` is the `list` of output `Tensor`(s) from the op. 263 Return value(s) from the callback are ignored. 264 """ 265 execute.execute = execute.execute_with_callbacks 266 context.get_default_context().add_post_execution_callback(callback) 267 268 269def clear_execution_callbacks(): 270 """Clear all execution callbacks from the default eager context.""" 271 context.get_default_context().clear_post_execution_callbacks() 272 273 274def seterr(inf_or_nan=None): 275 """Set how abnormal conditions are handled by the default eager context. 276 277 Example: 278 ```python 279 tfe.seterr(inf_or_nan="raise") 280 a = tf.constant(10.0) 281 b = tf.constant(0.0) 282 try: 283 c = a / b # <-- Raises InfOrNanError. 284 except Exception as e: 285 print("Caught Exception: %s" % e) 286 287 tfe.seterr(inf_or_nan="ignore") 288 c = a / b # <-- Does NOT raise exception anymore. 289 ``` 290 291 Args: 292 inf_or_nan: Set action for infinity (`inf`) and NaN (`nan`) values. 293 Possible values: `{"ignore", "print", "raise", "warn"}`. 294 `"ignore"`: take no action when `inf` values appear. 295 `"print"`: print a warning to `stdout`. 296 `"raise"`: raise an `InfOrNanError`. 297 `"warn"`: print a warning using `tf.logging.warn`. 298 A value of `None` leads to no change in the action of the condition. 299 300 Returns: 301 A dictionary of old actions. 302 303 Raises: 304 ValueError: If the value of any keyword arguments is invalid. 305 """ 306 if inf_or_nan not in _VALID_CALLBACK_ACTIONS: 307 raise ValueError( 308 "Invalid action value for inf_or_nan: %s. " 309 "Valid actions are %s." % (inf_or_nan, _VALID_CALLBACK_ACTIONS)) 310 311 old_settings = {"inf_or_nan": "ignore"} 312 default_context = context.get_default_context() 313 314 carryover_callbacks = [] 315 for callback in default_context.post_execution_callbacks: 316 # Check whether the callback is inf_nan_callback or a partial object of 317 # inf_nan_callback. 318 if (callback == inf_nan_callback or 319 isinstance(callback, functools.partial) and 320 callback.func == inf_nan_callback): 321 if callback == inf_nan_callback: 322 old_settings["inf_or_nan"] = _DEFAULT_CALLBACK_ACTION 323 else: 324 old_settings["inf_or_nan"] = callback.keywords.get( 325 "action", _DEFAULT_CALLBACK_ACTION) 326 elif inf_or_nan is not None: 327 carryover_callbacks.append(callback) 328 329 if inf_or_nan is not None: 330 default_context.clear_post_execution_callbacks() 331 for callback in carryover_callbacks: 332 default_context.add_post_execution_callback(callback) 333 if inf_or_nan != "ignore": 334 default_context.add_post_execution_callback( 335 functools.partial(inf_nan_callback, action=inf_or_nan)) 336 337 return old_settings 338