# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decorator that produces a callable object that executes a TensorFlow graph.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import tape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


def _default_initializer(name, shape, dtype):
  """The default initializer for variables."""
  # pylint: disable=protected-access
  store = variable_scope._get_default_variable_store()
  initializer = store._get_default_initializer(name, shape=shape, dtype=dtype)
  # pylint: enable=protected-access
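  # _get_default_initializer returns an (initializer, initializing_from_value)
  # pair; only the initializer callable is needed here.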
  return initializer[0]


class _CapturedVariable(object):
  """Variable captured by graph_callable.

  Internal to the implementation of graph_callable. Created only by
  _VariableCapturingScope and used only to read the variable values when calling
  the function after the variables are initialized.
  """

  def __init__(self, name, initializer, shape, dtype, trainable):
    self.name = name
    if initializer is None:
      initializer = _default_initializer(name, shape, dtype)
    initial_value = lambda: initializer(shape, dtype=dtype)

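    # Create the variable eagerly so it is owned by this Python object and
    # persists across the temporary graphs used to trace the function.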
    with context.eager_mode():
      self.variable = resource_variable_ops.ResourceVariable(
          initial_value=initial_value, name=name, dtype=dtype,
          trainable=trainable)
    self.shape = shape
    self.dtype = dtype
    self.placeholder = None
    self.trainable = trainable

  def read(self, want_gradients=True):
    if want_gradients and self.trainable:
      # tape.watch_variable returns None; it only marks the variable as
      # watched, so read from self.variable directly afterwards.
      tape.watch_variable(self.variable)
    return self.variable.read_value()


class _VariableCapturingScope(object):
  """Variable-scope-like object which captures tf.get_variable calls.

  This is responsible for the main difference between the initialization
  version of a function object and the calling version of a function object.

  capturing_scope replaces calls to tf.get_variable with placeholder tensors to
  be fed the variable's current value. TODO(apassos): these placeholders should
  instead be objects implementing a similar API to tf.Variable, for full
  compatibility.

  initializing_scope replaces calls to tf.get_variable with creation of
  variables and initialization of their values. This allows eventual support of
  initialized_value and friends.

  TODO(apassos): once the eager mode layers API is implemented support eager
  func-to-object as well.
  """

  def __init__(self):
    self.variables = {}
    self.tf_variables = {}

  @contextlib.contextmanager
  def capturing_scope(self):
    """Context manager to capture variable creations.

    Replaces variable accesses with placeholders.

    Yields:
      nothing
    """
    # TODO(apassos) ignoring the regularizer and partitioner here; figure out
    # how to deal with these.
    def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32,  # pylint: disable=missing-docstring
                       initializer=None, regularizer=None, reuse=None,
                       trainable=True, collections=None, caching_device=None,  # pylint: disable=redefined-outer-name
                       partitioner=None, validate_shape=True,
                       use_resource=None):
      del getter, regularizer, partitioner, validate_shape, use_resource, dtype
      del collections, initializer, trainable, reuse, caching_device, shape
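      # Every variable accessed here must already have been created by a
      # prior pass under initializing_scope; this getter never creates one.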
      assert name in self.variables
      v = self.variables[name]
      return v.variable

    scope = variable_scope.get_variable_scope()
    with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
      yield

  @contextlib.contextmanager
  def initializing_scope(self):
    """Context manager to capture variable creations.

    Forcibly initializes all created variables.

    Yields:
      nothing
    """
    # TODO(apassos) ignoring the regularizer and partitioner here; figure out
    # how to deal with these.
    def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32,  # pylint: disable=missing-docstring
                       initializer=None, regularizer=None, reuse=None,
                       trainable=True, collections=None, caching_device=None,  # pylint: disable=redefined-outer-name
                       partitioner=None, validate_shape=True,
                       use_resource=None):
      del getter, regularizer, collections, caching_device, partitioner
      del use_resource, validate_shape
      if name in self.tf_variables:
        if reuse:
          return self.tf_variables[name].initialized_value()
        else:
          raise ValueError("Variable %s already exists, but reuse=%s was "
                           "specified; pass reuse=True to reuse it." %
                           (name, reuse))
      # TODO(apassos): ensure this is on the same device as above
      v = _CapturedVariable(name, initializer, shape, dtype, trainable)
      self.variables[name] = v

      graph_mode_resource = v.variable.handle
      if initializer is None:
        initializer = _default_initializer(name, shape, dtype)
      resource_variable_ops.shape_safe_assign_variable_handle(
          graph_mode_resource, v.variable.shape, initializer(shape, dtype))
      return v.variable

    scope = variable_scope.get_variable_scope()
    with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
      yield


class _InitializingFunctionObject(object):
  """Responsible for deciding which version of func-to-object to call.

  call_fn is the version which calls the function with the current values of
  the variables, and init_fn is the version which calls the function to
  initialize all variables.

  TODO(apassos): figure out a way to support initializing only _some_
  variables. This requires a way to pull out a variable's initialization code
  from the graph, which might not be possible in general.
  """

  def __init__(self, call_fn, init_fn, shape_and_dtypes):
    self._init_fn = init_fn
    self._call_fn = call_fn
    self.shape_and_dtypes = shape_and_dtypes
    self.flattened_shapes = [tensor_shape.as_shape(sd.shape) for sd in
                             nest.flatten(self.shape_and_dtypes)]

  @property
  def variables(self):
    return self._call_fn.variables

  def __call__(self, *args):
    nest.assert_same_structure(self.shape_and_dtypes, args, check_types=False)
    if not all([
        shape.is_compatible_with(arg.shape)
        for shape, arg in zip(self.flattened_shapes, nest.flatten(args))
    ]):
      raise ValueError(
          "Declared shapes do not match argument shapes: Expected %s, found %s."
          % (self.flattened_shapes, [arg.shape for arg in nest.flatten(args)]))

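    # Dispatch on initialization state: run the initializing version of the
    # function if no variable has been initialized yet, and the calling
    # version once all of them have been.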
    initialized = [resource_variable_ops.var_is_initialized_op(
        v.handle).numpy() for v in self._call_fn.variables]
    if all(initialized):
      for v in self._call_fn.variables:
        if v._trainable:  # pylint: disable=protected-access
          tape.watch_variable(v)
      return self._call_fn(*args)
    elif not any(initialized):
      return self._init_fn(*args)
    else:
      raise ValueError("Some, but not all, variables are initialized.")


def _get_graph_callable_inputs(shape_and_dtypes):
  """Maps specified shape_and_dtypes to graph inputs."""
  ret = []
  for x in shape_and_dtypes:
    if isinstance(x, ShapeAndDtype):
      ret.append(array_ops.placeholder(x.dtype, x.shape))
    elif isinstance(x, (tuple, list)):
      ret.append(_get_graph_callable_inputs(x))
    else:
      raise errors.InvalidArgumentError(
          None, None, "Expected the argument to @graph_callable to be a "
          "(possibly nested) list or tuple of ShapeAndDtype objects, "
          "but got an object of type: %s" % type(x))

  return tuple(ret) if isinstance(shape_and_dtypes, tuple) else ret


def _graph_callable_internal(func, shape_and_dtypes):
  """Defines and returns a template version of func.

  Under the hood we make two function objects, each wrapping a different
  version of the graph-mode code. One version immediately runs variable
  initialization before making the variable's Tensors available for use,
  while the other version replaces the Variables with placeholders which
  become function arguments and receive the variables' current values.

  Both versions are limited by the fact that this module does not implement a
  graph-mode Variable class with a convert_to_tensor(as_ref=True) method and
  an initialized_value method. This is fixable.

  Args:
    func: The tfe Python function to compile.
    shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects.

  Raises:
    ValueError: If any one of func's outputs is not a Tensor.

  Returns:
    Callable graph object.
  """
  container = tf_ops.get_default_graph()._container  # pylint: disable=protected-access
  graph_key = tf_ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  with context.graph_mode():
    # This graph will store both the initialization and the call version of the
    # wrapped function. It will later be used by the backprop code to build the
    # backprop graph, if necessary.
    captures = {}
    tmp_graph = function.CapturingGraph(captures)
    # Inherit the graph key from the original graph to ensure optimizers don't
    # misbehave.
    tmp_graph._container = container  # pylint: disable=protected-access
    tmp_graph._graph_key = graph_key  # pylint: disable=protected-access
    with tmp_graph.as_default():
      # Placeholders for the non-variable inputs.
      func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
      func_num_args = len(tf_inspect.getargspec(func).args)
      if len(func_inputs) != func_num_args:
        raise TypeError("The number of arguments accepted by the decorated "
                        "function `%s` (%d) must match the number of "
                        "ShapeAndDtype objects passed to the graph_callable() "
                        "decorator (%d)." %
                        (func.__name__, func_num_args, len(func_inputs)))

      # First call the function to generate a graph which can initialize all
      # variables. As a side-effect this will populate the variable capturing
      # scope's view of which variables exist.
      variable_captures = _VariableCapturingScope()
      with variable_captures.initializing_scope(), function.capture_tensors(
          captures):
        func_outputs = func(*func_inputs)
      outputs_list = nest.flatten(func_outputs)
      if len(outputs_list) == 1 and outputs_list[0] is None:
        outputs_list = []
      output_shapes = [x.shape for x in outputs_list]
      if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
        raise ValueError("Found non-tensor output in %s" % str(outputs_list))
      initializing_operations = tmp_graph.get_operations()

      # Call the function again, now replacing usages of variables with
      # placeholders. This assumes the variable capturing scope created above
      # knows about all variables.
      tmp_graph.clear_resource_control_flow_state()
      with variable_captures.capturing_scope(), function.capture_tensors(
          captures):
        captured_outputs = func(*func_inputs)
      captured_outlist = nest.flatten(captured_outputs)
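      # Ops recorded after the first trace belong to the calling version of
      # the function; slicing past the initializing ops separates the two
      # bodies even though both live in the same temporary graph.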
      capturing_operations = tmp_graph.get_operations()[
          len(initializing_operations):]

  sorted_variables = sorted(variable_captures.variables.values(),
                            key=lambda x: x.name)
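  # `captures` maps each tensor captured from an outer context to the
  # placeholder that stands in for it inside tmp_graph; the captured tensors
  # become extra inputs fed alongside the declared arguments.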
  ids = sorted(captures.keys())
  if ids:
    extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
  else:
    extra_inputs = []
    extra_placeholders = []

  flat_inputs = [x for x in nest.flatten(func_inputs)
                 if isinstance(x, tf_ops.Tensor)]
  placeholder_inputs = flat_inputs + list(extra_placeholders)

  func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
  initialization_name = function._inference_name(func.__name__)  # pylint: disable=protected-access
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  # Also, what about the gradient registry of these functions? Those need to be
  # addressed as well.
  for f in tmp_graph._functions.values():  # pylint: disable=protected-access
    function._register(f._c_func)  # pylint: disable=protected-access
  initializer_function = function.GraphModeFunction(
      initialization_name,
      placeholder_inputs,
      extra_inputs,
      tmp_graph,
      initializing_operations,
      func_def_outputs,
      func_outputs,
      output_shapes)

  capture_func_def_outputs = [
      x for x in captured_outlist if isinstance(x, tf_ops.Tensor)]
  captured_function_name = function._inference_name(func.__name__)  # pylint: disable=protected-access
  captured_function = function.GraphModeFunction(
      captured_function_name,
      placeholder_inputs,
      extra_inputs,
      tmp_graph,
      capturing_operations,
      capture_func_def_outputs,
      captured_outputs,
      output_shapes,
      variables=[x.variable for x in sorted_variables])

  return _InitializingFunctionObject(captured_function, initializer_function,
                                     shape_and_dtypes)


class ShapeAndDtype(object):
  """Data type that packages together shape and type information.

  Used for arguments to graph callables. See graph_callable() for an example.
  """

  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype
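

# A minimal sketch (not executed) of how ShapeAndDtype specs are typically
# assembled; the shapes below are hypothetical, chosen only to illustrate that
# the nesting of the specs mirrors the nesting of the decorated function's
# arguments (see _get_graph_callable_inputs above):
#
#   specs = [ShapeAndDtype(shape=(2, 3), dtype=dtypes.float32),
#            (ShapeAndDtype(shape=(3,), dtype=dtypes.float32),
#             ShapeAndDtype(shape=(), dtype=dtypes.int32))]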

def graph_callable(shape_and_dtypes):
  """Decorator that produces a callable that executes a TensorFlow graph.

  When applied on a function that constructs a TensorFlow graph, this decorator
  produces a callable object that:

  1. Executes the graph when invoked. The first call will initialize any
     variables defined in the graph.

  2. Provides a .variables attribute which returns the list of TensorFlow
     variables defined in the graph.

  Note that the wrapped function is not allowed to change the values of the
  variables; it may only use them.

  The return value of the wrapped function must be one of the following:
  (1) None, (2) a Tensor, or (3) a possibly nested sequence of Tensors.

  Example:

  ```python
  @tfe.graph_callable([tfe.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
  def foo(x):
    v = tf.get_variable('v', initializer=tf.ones_initializer(), shape=())
    return v + x

  ret = foo(tfe.Tensor(2.0))  # `ret` here is a Tensor with value 3.0.

  foo.variables[0].assign(7.0)  # Modify the value of variable `v`.
  ret = foo(tfe.Tensor(2.0))  # `ret` here now is a Tensor with value 9.0.
  ```

  Args:
    shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects
      that specifies shape and type information for each of the callable's
      arguments. The length of this list must be equal to the number of
      arguments accepted by the wrapped function.

  Returns:
    A callable graph object.
  """
  # TODO(alive,apassos): support initialized_value and friends from tf.Variable.
  assert context.in_eager_mode(), (
      "graph_callable can only be used when Eager execution is enabled.")
  def decorator(func):
    return tf_decorator.make_decorator(func,
                                       _graph_callable_internal(
                                           func, shape_and_dtypes))

  return decorator