1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Decorator that produces a callable object that executes a TensorFlow graph. 16""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import contextlib 23 24from tensorflow.python.eager import context 25from tensorflow.python.eager import function 26from tensorflow.python.eager import tape 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import errors 29from tensorflow.python.framework import ops as tf_ops 30from tensorflow.python.framework import tensor_shape 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import resource_variable_ops 33from tensorflow.python.ops import variable_scope 34from tensorflow.python.util import nest 35from tensorflow.python.util import tf_decorator 36from tensorflow.python.util import tf_inspect 37 38 39def _default_initializer(name, shape, dtype): 40 """The default initializer for variables.""" 41 # pylint: disable=protected-access 42 store = variable_scope._get_default_variable_store() 43 initializer = store._get_default_initializer(name, shape=shape, dtype=dtype) 44 # pylint: enable=protected-access 45 return initializer[0] 46 47 48class _CapturedVariable(object): 49 """Variable captured by graph_callable. 50 51 Internal to the implementation of graph_callable. Created only by 52 _VariableCapturingScope and used only to read the variable values when calling 53 the function after the variables are initialized. 54 """ 55 56 def __init__(self, name, initializer, shape, dtype, trainable): 57 self.name = name 58 if initializer is None: 59 initializer = _default_initializer(name, shape, dtype) 60 initial_value = lambda: initializer(shape, dtype=dtype) 61 62 with context.eager_mode(): 63 self.variable = resource_variable_ops.ResourceVariable( 64 initial_value=initial_value, name=name, dtype=dtype, 65 trainable=trainable) 66 self.shape = shape 67 self.dtype = dtype 68 self.placeholder = None 69 self.trainable = trainable 70 71 def read(self, want_gradients=True): 72 if want_gradients and self.trainable: 73 v = tape.watch_variable(self.variable) 74 else: 75 v = self.variable 76 return v.read_value() 77 78 79class _VariableCapturingScope(object): 80 """Variable-scope-like object which captures tf.get_variable calls. 81 82 This is responsible for the main difference between the initialization version 83 of a function object and the calling version of a function object. 84 85 capturing_scope replaces calls to tf.get_variable with placeholder tensors to 86 be fed the variable's current value. TODO(apassos): these placeholders should 87 instead be objects implementing a similar API to tf.Variable, for full 88 compatibility. 89 90 initializing_scope replaces calls to tf.get_variable with creation of 91 variables and initialization of their values. This allows eventual support of 92 initialized_value and friends. 93 94 TODO(apassos): once the eager mode layers API is implemented support eager 95 func-to-object as well. 96 """ 97 98 def __init__(self): 99 self.variables = {} 100 self.tf_variables = {} 101 102 @contextlib.contextmanager 103 def capturing_scope(self): 104 """Context manager to capture variable creations. 105 106 Replaces variable accesses with placeholders. 107 108 Yields: 109 nothing 110 """ 111 # TODO(apassos) ignoring the regularizer and partitioner here; figure out 112 # how to deal with these. 113 def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring 114 initializer=None, regularizer=None, reuse=None, 115 trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name 116 partitioner=None, validate_shape=True, 117 use_resource=None): 118 del getter, regularizer, partitioner, validate_shape, use_resource, dtype 119 del collections, initializer, trainable, reuse, caching_device, shape, 120 assert name in self.variables 121 v = self.variables[name] 122 return v.variable 123 124 scope = variable_scope.get_variable_scope() 125 with variable_scope.variable_scope(scope, custom_getter=_custom_getter): 126 yield 127 128 @contextlib.contextmanager 129 def initializing_scope(self): 130 """Context manager to capture variable creations. 131 132 Forcibly initializes all created variables. 133 134 Yields: 135 nothing 136 """ 137 # TODO(apassos) ignoring the regularizer and partitioner here; figure out 138 # how to deal with these. 139 def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring 140 initializer=None, regularizer=None, reuse=None, 141 trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name 142 partitioner=None, validate_shape=True, 143 use_resource=None): 144 del getter, regularizer, collections, caching_device, partitioner 145 del use_resource, validate_shape 146 if name in self.tf_variables: 147 if reuse: 148 return self.tf_variables[name].initialized_value() 149 else: 150 raise ValueError("Specified reuse=%s but tried to reuse variables." 151 % reuse) 152 # TODO(apassos): ensure this is on the same device as above 153 v = _CapturedVariable(name, initializer, shape, dtype, trainable) 154 self.variables[name] = v 155 156 graph_mode_resource = v.variable.handle 157 if initializer is None: 158 initializer = _default_initializer(name, shape, dtype) 159 resource_variable_ops.shape_safe_assign_variable_handle( 160 graph_mode_resource, v.variable.shape, initializer(shape, dtype)) 161 return v.variable 162 163 scope = variable_scope.get_variable_scope() 164 with variable_scope.variable_scope(scope, custom_getter=_custom_getter): 165 yield 166 167 168class _InitializingFunctionObject(object): 169 """Responsible for deciding which version of func-to-object to call. 170 171 call_fn is the version which calls the function with the current values of the 172 variables and init_fn is the version which calls the function to initialize 173 all variables. 174 175 TODO(apassos): figure out a way to support initializing only _some_ 176 variables. This requires a way to pull out a variable's initialization code 177 from the graph, which might not be possible in general. 178 """ 179 180 def __init__(self, call_fn, init_fn, shape_and_dtypes): 181 self._init_fn = init_fn 182 self._call_fn = call_fn 183 self.shape_and_dtypes = shape_and_dtypes 184 self.flattened_shapes = [tensor_shape.as_shape(sd.shape) for sd in 185 nest.flatten(self.shape_and_dtypes)] 186 187 @property 188 def variables(self): 189 return self._call_fn.variables 190 191 def __call__(self, *args): 192 nest.assert_same_structure(self.shape_and_dtypes, args, check_types=False) 193 if not all([ 194 shape.is_compatible_with(arg.shape) 195 for shape, arg in zip(self.flattened_shapes, nest.flatten(args)) 196 ]): 197 raise ValueError( 198 "Declared shapes do not match argument shapes: Expected %s, found %s." 199 % (self.flattened_shapes, [arg.shape for arg in nest.flatten(args)])) 200 201 initialized = [resource_variable_ops.var_is_initialized_op( 202 v.handle).numpy() for v in self._call_fn.variables] 203 if all(x for x in initialized): 204 for v in self._call_fn.variables: 205 if v._trainable: # pylint: disable=protected-access 206 tape.watch_variable(v) 207 return self._call_fn(*args) 208 elif all(not x for x in initialized): 209 return self._init_fn(*args) 210 else: 211 raise ValueError("Some, but not all, variables are initialized.") 212 213 214def _get_graph_callable_inputs(shape_and_dtypes): 215 """Maps specified shape_and_dtypes to graph inputs.""" 216 ret = [] 217 for x in shape_and_dtypes: 218 if isinstance(x, ShapeAndDtype): 219 ret.append(array_ops.placeholder(x.dtype, x.shape)) 220 elif isinstance(x, (tuple, list)): 221 ret.append(_get_graph_callable_inputs(x)) 222 else: 223 raise errors.InvalidArgumentError( 224 None, None, "Expected the argument to @graph_callable to be a " 225 "(possibly nested) list or tuple of ShapeAndDtype objects, " 226 "but got an object of type: %s" % type(x)) 227 228 return tuple(ret) if isinstance(shape_and_dtypes, tuple) else ret 229 230 231def _graph_callable_internal(func, shape_and_dtypes): 232 """Defines and returns a template version of func. 233 234 Under the hood we make two function objects, each wrapping a different version 235 of the graph-mode code. One version immediately runs variable initialization 236 before making the variable's Tensors available for use, while the other 237 version replaces the Variables with placeholders which become function 238 arguments and get the current variable's value. 239 240 Limitations in (2) and (4) are because this does not implement a graph-mode 241 Variable class which has a convert_to_tensor(as_ref=True) method and a 242 initialized_value method. This is fixable. 243 244 Args: 245 func: The tfe Python function to compile. 246 shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects. 247 248 Raises: 249 ValueError: If any one of func's outputs is not a Tensor. 250 251 Returns: 252 Callable graph object. 253 """ 254 container = tf_ops.get_default_graph()._container # pylint: disable=protected-access 255 graph_key = tf_ops.get_default_graph()._graph_key # pylint: disable=protected-access 256 with context.graph_mode(): 257 # This graph will store both the initialization and the call version of the 258 # wrapped function. It will later be used by the backprop code to build the 259 # backprop graph, if necessary. 260 captures = {} 261 tmp_graph = function.CapturingGraph(captures) 262 # Inherit the graph key from the original graph to ensure optimizers don't 263 # misbehave. 264 tmp_graph._container = container # pylint: disable=protected-access 265 tmp_graph._graph_key = graph_key # pylint: disable=protected-access 266 with tmp_graph.as_default(): 267 # Placeholders for the non-variable inputs. 268 func_inputs = _get_graph_callable_inputs(shape_and_dtypes) 269 func_num_args = len(tf_inspect.getargspec(func).args) 270 if len(func_inputs) != func_num_args: 271 raise TypeError("The number of arguments accepted by the decorated " 272 "function `%s` (%d) must match the number of " 273 "ShapeAndDtype objects passed to the graph_callable() " 274 "decorator (%d)." % 275 (func.__name__, func_num_args, len(func_inputs))) 276 277 # First call the function to generate a graph which can initialize all 278 # variables. As a side-effect this will populate the variable capturing 279 # scope's view of which variables exist. 280 variable_captures = _VariableCapturingScope() 281 with variable_captures.initializing_scope(), function.capture_tensors( 282 captures): 283 func_outputs = func(*func_inputs) 284 outputs_list = nest.flatten(func_outputs) 285 if len(outputs_list) == 1 and outputs_list[0] is None: 286 outputs_list = [] 287 output_shapes = [x.shape for x in outputs_list] 288 if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list): 289 raise ValueError("Found non-tensor output in %s" % str(outputs_list)) 290 initializing_operations = tmp_graph.get_operations() 291 292 # Call the function again, now replacing usages of variables with 293 # placeholders. This assumes the variable capturing scope created above 294 # knows about all variables. 295 tmp_graph.clear_resource_control_flow_state() 296 with variable_captures.capturing_scope(), function.capture_tensors( 297 captures): 298 captured_outputs = func(*func_inputs) 299 captured_outlist = nest.flatten(captured_outputs) 300 capturing_operations = tmp_graph.get_operations()[ 301 len(initializing_operations):] 302 303 sorted_variables = sorted(variable_captures.variables.values(), 304 key=lambda x: x.name) 305 ids = list(sorted(captures.keys())) 306 if ids: 307 extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids]) 308 else: 309 extra_inputs = [] 310 extra_placeholders = [] 311 312 flat_inputs = [x for x in nest.flatten(func_inputs) 313 if isinstance(x, tf_ops.Tensor)] 314 placeholder_inputs = flat_inputs+ list(extra_placeholders) 315 316 func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)] 317 initialization_name = function._inference_name(func.__name__) # pylint: disable=protected-access 318 # TODO(ashankar): Oh lord, forgive me for this lint travesty. 319 # Also, what about the gradient registry of these functions? Those need to be 320 # addressed as well. 321 for f in tmp_graph._functions.values(): # pylint: disable=protected-access 322 function._register(f._c_func) # pylint: disable=protected-access 323 initializer_function = function.GraphModeFunction( 324 initialization_name, 325 placeholder_inputs, 326 extra_inputs, 327 tmp_graph, 328 initializing_operations, 329 func_def_outputs, 330 func_outputs, 331 output_shapes) 332 333 capture_func_def_outputs = [ 334 x for x in captured_outlist if isinstance(x, tf_ops.Tensor)] 335 captured_function_name = function._inference_name(func.__name__) # pylint: disable=protected-access 336 captured_function = function.GraphModeFunction( 337 captured_function_name, 338 placeholder_inputs, 339 extra_inputs, 340 tmp_graph, 341 capturing_operations, 342 capture_func_def_outputs, 343 captured_outputs, 344 output_shapes, 345 variables=[x.variable for x in sorted_variables]) 346 347 return _InitializingFunctionObject(captured_function, initializer_function, 348 shape_and_dtypes) 349 350 351class ShapeAndDtype(object): 352 """Data type that packages together shape and type information. 353 354 Used for arguments to graph callables. See graph_callable() for an example. 355 """ 356 357 def __init__(self, shape, dtype): 358 self.shape = shape 359 self.dtype = dtype 360 361 362def graph_callable(shape_and_dtypes): 363 """Decorator that produces a callable that executes a TensorFlow graph. 364 365 When applied on a function that constructs a TensorFlow graph, this decorator 366 produces a callable object that: 367 368 1. Executes the graph when invoked. The first call will initialize any 369 variables defined in the graph. 370 371 2. Provides a .variables() method to return the list of TensorFlow variables 372 defined in the graph. 373 374 Note that the wrapped function is not allowed to change the values of the 375 variables, just use them. 376 377 The return value of the wrapped function must be one of the following: 378 (1) None, (2) a Tensor, or (3) a possibly nested sequence of Tensors. 379 380 Example: 381 382 ```python 383 @tfe.graph_callable([tfe.ShapeAndDtype(shape(), dtype=dtypes.float32)]) 384 def foo(x): 385 v = tf.get_variable('v', initializer=tf.ones_initializer(), shape=()) 386 return v + x 387 388 ret = foo(tfe.Tensor(2.0)) # `ret` here is a Tensor with value 3.0. 389 390 foo.variables[0].assign(7.0) # Modify the value of variable `v`. 391 ret = foo(tfe.Tensor(2.0)) # `ret` here now is a Tensor with value 9.0. 392 ``` 393 Args: 394 shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects 395 that specifies shape and type information for each of the callable's 396 arguments. The length of this list must be equal to the number of 397 arguments accepted by the wrapped function. 398 399 Returns: 400 A callable graph object. 401 """ 402 # TODO(alive,apassos): support initialized_value and friends from tf.Variable. 403 assert context.in_eager_mode(), ( 404 "graph_callable can only be used when Eager execution is enabled.") 405 def decorator(func): 406 return tf_decorator.make_decorator(func, 407 _graph_callable_internal( 408 func, shape_and_dtypes)) 409 410 return decorator 411