function.py revision 90e42f3ac8c43474633136af4242dca04b6a1e09
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Python front-end support for functions.

NOTE: functions are currently experimental and subject to change!
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import hashlib

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.eager import context
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


class Defun(object):
  """Decorator used to define TensorFlow functions.

  Use this decorator to make a Python function usable directly as a TensorFlow
  function.

  The decorated function must add ops to the default graph and return zero or
  more `Tensor` objects.  Call the decorator with named arguments, one for each
  argument of the function to decorate, with the expected type of the argument
  as value.

  For example if the function to decorate accepts two `tf.float32` arguments
  named `x` and `y`, call the decorator with:

      @Defun(tf.float32, tf.float32)
      def foo(x, y):
        ...

  When you call the decorated function, it adds `call` ops to the default
  graph, along with the definition of the function itself. Because the
  addition of the function to the graph is deferred, the decorator can be
  used anywhere in the program.

  Any variables created inside of the function are hoisted into the outer graph.
  Note that the variables are created in the variable scope that was active
  during the first call to the function. Subsequent function calls will refer to
  the same set of variables.

  Definitions of functions are frozen in a graph as soon as the graph is used to
  create a session. Therefore, nodes using the function must be created in the
  graph before the corresponding session is created.

  For example (see also the [How To on functions](link_needed)):

  ```python
  # Defining the function.
  @tf.Defun(tf.float32, tf.float32)
  def MyFunc(x, y):
    return x + y, x - y

  # Building the graph.
  a = tf.constant([1.0])
  b = tf.constant([2.0])
  c, d = MyFunc(a, b, name='mycall')
  ```
  """

  def __init__(self, *input_types, **kwargs):
    """Create a `Defun` decorator.

    Args:
      *input_types: A list of `tf.DType`.
      **kwargs: Optional keyword arguments, including
         func_name - (optional).  A python string, the name to use to
           declare this `Function` in the graph.

         grad_func - (optional).  A function implementing the gradient
           of the function-to-register.  This must be a
           `_DefinedFunction` object. The gradient
           function must satisfy the criterion defined in
           function.proto:GradientDef.

         python_grad_func - (optional).  A function implementing the
           gradient of the function python-side. This function must
           take the current op and the gradients w.r.t. its outputs,
           and return the gradients w.r.t. the inputs. That is, it must
           implement the interface expected by `tf.RegisterGradient`.
           This will be called by tf.gradients to add the gradient ops
           to the graph. At most one of grad_func and python_grad_func
           can be specified.

         out_names - (optional). A list of strings, one per output
           tensor.

         shape_func - (optional). A function taking the op and returning a list
           of static shapes to set for the function's outputs.
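
         For example, a `python_grad_func` for a function that squares its
         input might look like the following sketch (the names `Square` and
         `_square_grad` are illustrative, not part of this API):

             def _square_grad(op, grad):
               # d(x*x)/dx == 2*x; `grad` is the gradient w.r.t. the output.
               return grad * 2.0 * op.inputs[0]

             @Defun(tf.float32, python_grad_func=_square_grad)
             def Square(x):
               return x * x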
    """
    self._input_types = input_types
    self._func_name = kwargs.pop("func_name", None)
    self._grad_func = kwargs.pop("grad_func", None)
    self._python_grad_func = kwargs.pop("python_grad_func", None)
    self._out_names = kwargs.pop("out_names", None)
    self._extra_kwargs = kwargs

  def __call__(self, func):
    # Various sanity checks on the callable func.
    if not callable(func):
      raise ValueError("func %s must be callable" % func)

    # Func should not use kwargs and defaults.
    argspec = tf_inspect.getargspec(func)
    if argspec.keywords or argspec.defaults:
      raise ValueError("Functions with argument defaults or keyword "
                       "arguments are not supported.")

    # Computes how many arguments 'func' has.
    min_args = len(argspec.args)
    max_args = min_args
    if argspec.varargs:
      max_args = 1000000
    argnames = argspec.args
    if tf_inspect.ismethod(func):
      # 1st argument is the "class" type.
      min_args -= 1
      argnames = argnames[1:]

    if self._input_types:
      # If Defun is given a list of types for the inputs, the number
      # of input types should be compatible with 'func'.
      num = len(self._input_types)
      if num < min_args or num > max_args:
        raise ValueError(
            "The number of specified input types (%d) does not match the "
            "number of arguments that the function accepts (between %d and "
            "%d)." % (num, min_args, max_args))
      return _DefinedFunction(
          func,
          argnames,
          self._input_types,
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # 'func' expects no arguments and input types is an empty list.
    if min_args == 0 and max_args == 0:
      return _DefinedFunction(
          func, [], [],
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # Input types are unknown. It's an overloaded function and hence
    # its definition needs to be deferred until it's called.
    return _OverloadedFunction(
        func,
        argnames,
        self._func_name,
        self._grad_func,
        self._python_grad_func,
        out_names=self._out_names,
        **self._extra_kwargs)


class _DefinedFunction(object):
  """_DefinedFunction encapsulates a function definition and its properties.

  Attributes:
    name: The function name.
    definition: The definition of this function. A FunctionDef proto.
    grad_func_name: If not None, the name of this function's gradient function.
    python_grad_func: A python callable implementing the gradient of
      the function python-side.
  """

  def __init__(self,
               func,
               argnames,
               input_types,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               shape_func=None,
               capture_by_value=False,
               **kwargs):
    """Creates _DefinedFunction.

    Args:
      func:  A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      input_types: The function's argument types. Can be a tuple or list of
        tf data types.
      func_name: The function name. Defaults to None, in which case the
        name is derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: An optional list of strings for the function return value
        names.
      shape_func: An optional function mapping an op to a list of static
        output shapes.
      capture_by_value: Boolean (defaults to False). If True, captured values
        will be copied into the function body.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.

    """
    self._func = func
    self._input_types = input_types
    self._func_name = func_name
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._shape_func = shape_func
    self._capture_by_value = capture_by_value
    self._extra_kwargs = kwargs
    # Constructed only when C API is disabled, lazily
    self._definition = None
    # Constructed only when C API is enabled, lazily
    self._c_func = None
    self._sub_functions = dict()  # Constructed with _definition or _c_func

    # Cached OpDef for this function. When C API is enabled, this is
    # the only part of FunctionDef that we cache in Python. When C API
    # is disabled the whole _definition is available and this is simply
    # another reference to _definition.signature
    self._op_def = None

    self._args = []
    assert isinstance(input_types, (list, tuple))
    for i in range(len(input_types)):
      argname = argnames[i] if i < len(argnames) else ("arg%d" % i)
      argtype = input_types[i]
      self._args.append((argname, argtype))

  @property
  def name(self):
    """Function name."""
    self._create_definition_if_needed()
    return self._func_name

  @property
  def definition(self):
    """Function definition proto."""
    self._create_definition_if_needed()
    if self._c_func:
      with c_api_util.tf_buffer() as buf:
        with errors.raise_exception_on_not_ok_status() as status:
          c_api.TF_FunctionToFunctionDef(self._c_func, buf, status)
        fdef = function_pb2.FunctionDef()
        proto_data = c_api.TF_GetBuffer(buf)
        fdef.ParseFromString(compat.as_bytes(proto_data))
      return fdef
    return self._definition

  @property
  def _signature(self):
    self._create_definition_if_needed()
    return self._op_def

  def set_grad_func(self, grad_func):
    """Specifies the gradient function of this function."""
    assert not self._grad_func
    assert isinstance(grad_func, _DefinedFunction)
    self._grad_func = grad_func

  @property
  def grad_func_name(self):
    """Its gradient function's name."""
    return self._grad_func.name if self._grad_func else None

  @property
  def python_grad_func(self):
    """Python gradient function callable."""
    return self._python_grad_func

  @property
  def declared_input_types(self):
    """Returns the list of data types of explicitly declared inputs."""
    return self._input_types

  @property
  def captured_inputs(self):
    """Returns the list of implicitly captured inputs."""
    self._create_definition_if_needed()
    return self._extra_inputs

  def _create_definition_if_needed(self):
    """Creates the function definition if it's not created yet."""
    with context.graph_mode():
      self._create_definition_if_needed_impl()

  def _create_definition_if_needed_impl(self):
    """This is not what you want; see _create_definition_if_needed."""
    if self._definition is not None or self._c_func is not None:
      return

    # Create the func_def object.
    temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
    with temp_graph.as_default():
      # List of placeholders for the function_def.
      inputs = []
      for (argname, argtype) in self._args:
        argholder = array_ops.placeholder(argtype, name=argname)
        inputs.append(argholder)
      # Call func and gather the output tensors.
      with vs.variable_scope("", custom_getter=temp_graph.getvar):
        outputs = self._func(*inputs)

      # There is no way of distinguishing between a function not returning
      # anything and a function returning None in Python.
      # We need to allow the former and ideally want to forbid the latter as
      # it is most likely user error.
      # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
      # allow users to explicitly mark the function as not returning anything.
      # For now, we allow a single None return and interpret it as a function
      # with no output.
      if outputs is None:
        outputs = []
      else:
        # If func only returned one value, make it a tuple.
        if not isinstance(outputs, (list, tuple)):
          outputs = (outputs,)
        if any([_ is None for _ in outputs]):
          raise ValueError("Function cannot return None.")
      # Ensures each output is a Tensor.
      outputs = [ops.convert_to_tensor(_) for _ in outputs]
    self._extra_inputs = temp_graph.extra_inputs
    inputs.extend(temp_graph.extra_args)
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    base_func_name = self._func_name or _get_func_name(self._func)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
                                         **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          inputs,
          outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      with errors.raise_exception_on_not_ok_status() as status:
        self._c_func = c_api.TF_GraphToFunction_wrapper(
            temp_graph._c_graph,
            base_func_name,
            self._func_name is None,  # append_hash_to_fn_name
            None,  # opers
            [t._as_tf_output() for t in inputs],
            [t._as_tf_output() for t in outputs],
            output_names,
            None,  # opts
            description,
            status)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = self._op_def.name

  def _set_c_attrs(self, attrs):
    """Sets `attrs` as attributes of self._c_func.

    Requires that self._c_func is not None.

    Args:
      attrs: a dictionary from attribute name to attribute proto value
    """
    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use the same status.
      with errors.raise_exception_on_not_ok_status() as status:
        c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name),
                                           serialized, status)

  def _create_hash_str(self, input_arg, output_arg, node_def):
    """Creates an 8-character string unique to this input.

    Args:
      input_arg: the input_arg field of an OpDef
                 (e.g. self._definition.signature.input_arg)
      output_arg: the output_arg field of an OpDef
                 (e.g. self._definition.signature.output_arg)
      node_def: the node_def field of a FunctionDef
                (e.g. self._definition.node_def)

    Returns:
      The unique string for this input
    """
    hasher = hashlib.sha1()

    def update_num(n):
      hasher.update(compat.as_bytes("%x" % n))

    def update_str(s):
      update_num(len(s))
      hasher.update(compat.as_bytes(s))

    def update_strs(slist):
      update_num(len(slist))
      for s in slist:
        update_str(s)

    for adef in input_arg:
      update_str(adef.SerializeToString())

    for adef in output_arg:
      update_str(adef.SerializeToString())

    for n in sorted(node_def, key=lambda n: n.name):
      update_str(n.name)
      update_str(n.op)
      update_strs(n.input)
      update_num(len(n.attr))
      # NOTE: protobuf map serialization does not guarantee ordering.
      for k in sorted(n.attr):
        update_str(k)
        update_str(n.attr[k].SerializeToString())

    return hasher.hexdigest()[:8]

  def add_to_graph(self, g):
    """Adds this function into the graph g."""
    self._create_definition_if_needed()

    # Adds this function into 'g'.
    # pylint: disable=protected-access
    if context.in_graph_mode():
      g._add_function(self)
    else:
      context.context().add_function_def(self.definition)
    # pylint: enable=protected-access

    # Ensures related sub-routines are defined in 'g', too.
    for f in self._sub_functions.values():
      f.add_to_graph(g)

    # Adds its gradient function, too.
    if self._grad_func:
      self._grad_func.add_to_graph(g)

  def __call__(self, *args, **kwargs):
    self.add_to_graph(ops.get_default_graph())
    args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs
    ret, op = _call(self._signature, *args, **kwargs)
    if self._shape_func is not None:
      shapes = self._shape_func(op)
      if len(shapes) != len(op.outputs):
        raise ValueError("shape_func produced %d shapes for %d outputs" %
                         (len(shapes), len(op.outputs)))
      for (t, shape) in zip(op.outputs, shapes):
        t.set_shape(shape)
    return ret


class _OverloadedFunction(object):
  """_OverloadedFunction encapsulates an overloaded function.

  _OverloadedFunction maintains a mapping from input types to
  instantiated _DefinedFunction in self._overload.
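
  For example, a minimal sketch (the name `Add` is illustrative):

      @Defun()  # No input types given, so the function is overloaded.
      def Add(x, y):
        return x + y

      a = Add(tf.constant(1.0), tf.constant(2.0))  # float32 instantiation
      b = Add(tf.constant(1), tf.constant(2))      # int32 instantiation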

  """

  def __init__(self,
               func,
               argnames,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               **kwargs):
    """Creates _OverloadedFunction.

    Args:
      func:  A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      func_name: The function name. Defaults to None, in which case the
        name is derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: A list of strings for the function return value names.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.

    """
    self._func = func
    self._argnames = argnames
    self._func_name = func_name
    assert grad_func is None or isinstance(grad_func, _OverloadedFunction)
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._extra_kwargs = kwargs
    self._overload = {}

  def instantiate(self, input_types):
    """Instantiate this function given input argument types.

    Args:
      input_types: A list of data types for the inputs.

    Returns:
      _DefinedFunction for the given input types.

    """
    # Stringify the type list.
    key = _type_list_to_str(input_types)
    defined = self._overload.get(key)
    if not defined:
      # If not defined yet, define the function given the input types.
      name = self._func_name
      if name is not None:
        name = "_".join([name, key])
      defined = _DefinedFunction(
          self._func,
          self._argnames,
          input_types,
          name,
          None,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)
      _ = defined.name  # Fully instantiate the function definition.
      if self._grad_func:
        # If _grad_func is given, it is another
        # _OverloadedFunction. We need to instantiate it with the
        # right input types.
        output_types = [
            dtypes.DType(_.type)
            for _ in defined._signature.output_arg  # pylint: disable=protected-access
        ]
        # pylint: disable=protected-access
        defined._grad_func = self._grad_func.instantiate(
            input_types + output_types)
        # pylint: enable=protected-access
      self._overload[key] = defined
    return defined

  def __call__(self, *args, **kwargs):
    input_types = []
    args = list(args)
    for (i, x) in enumerate(args):
      x = ops.convert_to_tensor(x)
      if not isinstance(x, ops.Tensor):
        raise ValueError("Expected a Tensor but got %s" % x)
      input_types.append(x.dtype)
      args[i] = x
    return self.instantiate(input_types)(*args, **kwargs)


class _FuncGraph(ops.Graph):
  """A helper for constructing a function.

  _FuncGraph overrides ops.Graph's create_op() so that we can keep
  track of all inputs into every op created inside the function.  If
  any input is from another graph, we keep track of it in self._captured
  and substitute the input with a placeholder.

  Each captured input's corresponding placeholder is converted into a
  function argument and the caller passes in the captured tensor.
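
  For example, a minimal sketch of capturing (names are illustrative):

      x = tf.constant(1.0)  # Lives in the outer graph.

      @Defun(tf.float32)
      def Foo(y):
        return x + y  # Creating `x + y` inside the _FuncGraph captures `x`.

  While the body of `Foo` is traced, `x` is replaced by a placeholder that
  becomes an extra function argument; `x` itself is recorded in
  self.extra_inputs so the call site can pass it in.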
  """

  def __init__(self, capture_by_value, *args, **kwargs):
    super(_FuncGraph, self).__init__(*args, **kwargs)
    self._capture_by_value = capture_by_value
    self._building_function = True
    self._outer_graph = ops.get_default_graph()
    self._vscope = vs.get_variable_scope()
    self._old_custom_getter = self._vscope.custom_getter
    self._captured = {}
    self.extra_inputs = []
    self.extra_args = []
    self.extra_vars = []

  def getvar(
      self,
      getter,
      name,
      shape=None,
      dtype=None,
      initializer=None,
      reuse=None,
      trainable=True,
      collections=None,  # pylint: disable=redefined-outer-name
      use_resource=None,
      **kwargs):
    """A custom variable getter."""
    # Here, we switch the default graph to the outer graph and ask the
    # variable scope in which the function is defined to give us the
    # variable. The variable is stashed in extra_vars and returned to
    # the caller.
    #
    # We capture these variables so that the variable definition is
    # hoisted upward to the outermost graph.
    with self._outer_graph.as_default():
      # pylint: disable=protected-access
      var = self._vscope.get_variable(
          vs._get_default_variable_store(),
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          use_resource=use_resource)
      self.extra_vars.append(var)
      if isinstance(var, resource_variable_ops.ResourceVariable):
        # For resource-based variables read the variable outside the function
        # and pass in the value. This ensures that the function is pure and
        # differentiable. TODO(apassos) this may have performance problems if
        # the function will only do embedding lookups on the variable.
        return var.value()
      return var

  def create_op(self, op_type, inputs, data_types, **kwargs):
    for i, x in enumerate(inputs):
      if x.graph is not self:
        # Referring to a tensor from another graph.
        if x in self._captured:
          # Captured already.
          inputs[i] = self._captured[x]
        elif self._capture_by_value:
          inputs[i] = self._add_tensor_and_parents(x)
        else:
          # Substitute with a placeholder.
          self.extra_inputs.append(x)
          # Hoist the new input placeholder out of any control flow context
          # we're currently in.
          with ops.control_dependencies(None):
            ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
          # pylint: disable=protected-access
          ph._handle_data = x._handle_data
          # pylint: enable=protected-access
          inputs[i] = ph
          self._captured[x] = ph
          self.extra_args.append(ph)
    return super(_FuncGraph, self).create_op(op_type, inputs, data_types,
                                             **kwargs)

  def _add_tensor_and_parents(self, tensor):
    op = self._add_op_and_parents(tensor.op)
    return op.outputs[tensor.value_index]

  def _add_op_and_parents(self, op):
    # pylint: disable=protected-access
    op_def = graph_to_function_def._get_op_def(op)
    # pylint: enable=protected-access
    if op_def.is_stateful:
      raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
                       "by value." % (op.name, op.type))
    elif op.type in ("Placeholder", "PlaceholderV2"):
      raise ValueError("Cannot capture a placeholder (name:%s, type:%s) "
                       "by value." % (op.name, op.type))

    captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]

    captured_op = self.create_op(
        op.type,
        captured_inputs, [o.dtype for o in op.outputs],
        name=op.name,
        attrs=op.node_def.attr,
        op_def=op_def)

    for t, captured_t in zip(op.outputs, captured_op.outputs):
      self._captured[t] = captured_t

    return captured_op


def _call(sig, *inputs, **kwargs):
  """Adds a node calling a function.

  This adds a `call` op to the default graph that calls the function
  of signature `sig`, passing the tensors in `inputs` as arguments.
  It returns the outputs of the call, which are one or more tensors.

  `sig` is the `OpDef` signature of the function to call.

  You can pass an optional keyword parameter `name=string` to name the
  added operation.

  You can pass an optional keyword parameter `noinline=True|False` to
  instruct the runtime not to inline the function body into the call
  site.
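
  For example, a minimal sketch (illustrative; assumes `MyFunc` is a
  `_DefinedFunction` created with @Defun and `a`, `b` are matching tensors):

      ret, op = _call(MyFunc._signature, a, b, name="my_call", noinline=True)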

  Args:
    sig: OpDef. The signature of the function.
    *inputs: arguments to the function.
    **kwargs: Optional keyword arguments.  Can only contain 'name' or
        'noinline'.

  Returns:
     A 2-element tuple. First element: a Tensor if the function returns a single
     value; a tuple of Tensors if the function returns multiple values; the
     Operation if the function returns no values. Second element: the Operation.

  Raises:
    ValueError: if the arguments are invalid.
  """
  if len(inputs) != len(sig.input_arg):
    raise ValueError("Expected number of arguments: %d, received: %d" %
                     (len(sig.input_arg), len(inputs)))
  name = kwargs.pop("name", None)
  g = ops.get_default_graph()
  func_name = sig.name
  attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
  output_types = [dtypes.DType(x.type) for x in sig.output_arg]
  with ops.name_scope(name, func_name, inputs) as name:
    op = g.create_op(
        func_name,
        list(inputs),
        output_types,
        name=name,
        attrs=attrs,
        op_def=sig,
        compute_shapes=False)
  if op.outputs:
    if len(op.outputs) == 1:
      ret = op.outputs[0]
    else:
      ret = tuple(op.outputs)
  else:
    ret = op
  return ret, op


def _from_definition(fdef, grad_func=None):
  """Creates a _DefinedFunction initialized from a FunctionDef proto.

  Args:
    fdef: a FunctionDef
    grad_func: a _DefinedFunction or None

  Returns:
    A _DefinedFunction representing fdef
  """
  # TODO(iga): This method does major surgery on _DefinedFunction.
  # Make it a named constructor using @classmethod of _DefinedFunction.

  # The Python callable is only needed to create a FunctionDef. Since we have
  # the FunctionDef here, we don't need to set _DefinedFunction._func (nor do we
  # have access to such a callable here).
  func = None
  argnames = [arg.name for arg in fdef.signature.input_arg]
  input_types = tuple(
      dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg)
  func_name = fdef.signature.name
  # Note: FunctionDefs do not include python gradient functions, so if the
  # original _DefinedFunction included one it will not be reflected here.
  python_grad_func = None
  out_names = [arg.name for arg in fdef.signature.output_arg]
  result = _DefinedFunction(func, argnames, input_types, func_name, grad_func,
                            python_grad_func, out_names)
  # pylint: disable=protected-access
  if ops._USE_C_API:
    serialized = fdef.SerializeToString()
    with errors.raise_exception_on_not_ok_status() as status:
      result._c_func = c_api.TF_FunctionImportFunctionDef(serialized, status)
    result._extra_inputs = []
  else:
    result._definition = fdef
    # Captured inputs are added as regular inputs to a function when it's
    # serialized, i.e. any extra inputs from the original function are now
    # included in `result`._args
    result._extra_inputs = []
    result._hash_str = result._create_hash_str(
        result._definition.signature.input_arg,
        result._definition.signature.output_arg, result._definition.node_def)
  # pylint: enable=protected-access

  return result


def _from_library(lib):
  """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.

  This method handles assigning the correct gradient functions to each
  function.

  Args:
    lib: a FunctionDefLibrary

  Returns:
    A list of _DefinedFunctions

  Raises:
    ValueError: `lib` is invalid
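
  For example (an illustrative sketch): given a library containing
  FunctionDefs `F` and `FGrad` plus a GradientDef mapping `F` to `FGrad`,
  this returns _DefinedFunctions for both, with the one for `F` having its
  grad_func set to the one for `FGrad`.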
856  """
857  if not lib.function and not lib.gradient:
858    return []
859
860  # function name -> FunctionDef proto
861  funcs = {fdef.signature.name: fdef for fdef in lib.function}
862
863  # Validate that all references function names have function defs
864  for g in lib.gradient:
865    if g.function_name not in funcs:
866      raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" %
867                       (g.function_name, str(lib)))
868    if g.gradient_func not in funcs:
869      raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" %
870                       (g.gradient_func, str(lib)))
871
872  # function name -> gradient function name
873  func_to_grad = collections.defaultdict(lambda: None)
874  # gradient function name -> names of functions having that grad function
875  grad_to_funcs = collections.defaultdict(list)
876
877  for gdef in lib.gradient:
878    func_to_grad[gdef.function_name] = gdef.gradient_func
879    grad_to_funcs[gdef.gradient_func].append(gdef.function_name)
880
881  # Start with functions without gradients
882  ready = [
883      fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
884  ]
885  if not ready:
886    raise ValueError("FunctionDefLibrary contains cyclic gradient functions!\n"
887                     + str(lib))
888  # function name -> _DefinedFunction
889  initialized = {}
890
891  while ready:
892    fdef = ready.pop()
893    name = fdef.signature.name
894
895    grad = initialized.get(func_to_grad[name])
896    if func_to_grad[name]:
897      assert grad
898    defined_func = _from_definition(fdef, grad_func=grad)
899    initialized[name] = defined_func
900
901    ready.extend(funcs[f] for f in grad_to_funcs[name])
902
903  return initialized.values()
904
905
def _parse_kwargs_as_attrs(func_name, **kwargs):
  """Parses **kwargs into a node's attributes."""
  attrs = {}

  noinline = kwargs.pop("noinline", None)
  if noinline is not None:
    attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline))

  compiled = kwargs.pop("compiled", None)
  separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None)
  if compiled is not None:
    attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled))
    attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue(
        b=bool(separate_compiled_gradients))
    # Forward _XlaScope from enclosing context (if set), otherwise create new.
    # pylint: disable=protected-access
    if "_XlaScope" in ops.get_default_graph()._attr_scope_map:
      attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"]
    else:
      attrs["_XlaScope"] = attr_value_pb2.AttrValue(
          s=("function_%s" % func_name).encode())
    # pylint: enable=protected-access

  if kwargs:
    raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
  return attrs
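
# For example, a minimal sketch of the mapping performed by
# _parse_kwargs_as_attrs (illustrative; assumes no enclosing _XlaScope):
#
#   _parse_kwargs_as_attrs("Foo", noinline=True)
#     -> {"_noinline": attr_value_pb2.AttrValue(b=True)}
#   _parse_kwargs_as_attrs("Foo", compiled=True)
#     -> {"_XlaCompile": AttrValue(b=True),
#         "_XlaSeparateCompiledGradients": AttrValue(b=False),
#         "_XlaScope": AttrValue(s=b"function_Foo")}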


def _get_func_name(func):
  _, func = tf_decorator.unwrap(func)
  if callable(func):
    if tf_inspect.isfunction(func):
      return func.__name__
    elif tf_inspect.ismethod(func):
      return "%s.%s" % (func.__self__.__name__, func.__name__)
    else:  # Probably a class instance with __call__
      return type(func).__name__
  else:
    raise ValueError("Argument must be callable")


def get_extra_vars():
  """Returns the variables captured by the function.

  Returns:
    If the default graph is being used to define a function, the
    returned list contains the variables created inside the function
    body so far. Otherwise, returns an empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_vars
  else:
    return []


def get_extra_inputs():
  """Returns the input tensors captured by the function.

  Returns:
    If the default graph is being used to define a function, the
    returned list contains the tensors accessed inside the function
    body but defined outside of it so far. Otherwise, returns an
    empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_inputs
  else:
    return []


def get_extra_args():
  """Returns the corresponding function arguments for the captured inputs.

  Returns:
    If the default graph is being used to define a function, the
    returned list contains the placeholders used inside the function
    body that correspond to the tensors returned by get_extra_inputs().
    Otherwise, returns an empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_args
  else:
    return []
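
# A minimal sketch of how the get_extra_* helpers relate (illustrative;
# the lists reflect only what has been built in the body so far):
#
#   x = tf.constant(1.0)
#
#   @Defun(tf.float32)
#   def Foo(y):
#     z = x + y  # Building `x + y` captures the outer-graph tensor `x`.
#     # get_extra_inputs() -> [x], the captured outer-graph tensors.
#     # get_extra_args()   -> their placeholder stand-ins inside the body.
#     # get_extra_vars()   -> variables created in the body (none here).
#     return z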


def _type_list_to_str(types):
  if any([_ not in _DTYPE_TO_STR for _ in types]):
    raise ValueError("Unsupported dtypes: %s" % types)
  return "".join([_DTYPE_TO_STR[_] for _ in types])
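
# For example (illustrative): _type_list_to_str([dtypes.float32, dtypes.int32])
# returns "f32i32", the key used for _OverloadedFunction's overload cache and
# for disambiguating instantiated function names.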


# NOTE: The list needs to be extended when more data types are added.
_DTYPE_TO_STR = {
    dtypes.float16: "f16",
    dtypes.float32: "f32",
    dtypes.float64: "f64",
    dtypes.int32: "i32",
    dtypes.uint8: "u8",
    dtypes.uint16: "u16",
    dtypes.uint32: "u32",
    dtypes.uint64: "u64",
    dtypes.int16: "i16",
    dtypes.int8: "i8",
    dtypes.string: "s",
    dtypes.complex64: "c64",
    dtypes.complex128: "c128",
    dtypes.int64: "i64",
    dtypes.bool: "b",
    dtypes.qint8: "qi8",
    dtypes.quint8: "qu8",
    dtypes.qint16: "qi16",
    dtypes.quint16: "qu16",
    dtypes.qint32: "qi32",
    dtypes.bfloat16: "b16"
}