function.py revision 01925fe23a5d8ebf17a1ddfcfaa1503f67e32575
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Python front-end support for functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import inspect
import re

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops


def _make_argname_from_tensor_name(name):
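  # Strips a trailing ":0" and turns any remaining ":" into "_o", so a tensor
  # name such as "foo:0" becomes "foo" and "foo:1" becomes "foo_o1"
  # (illustrative names).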
  return re.sub(":0$", "", name).replace(":", "_o")


def _tensor_to_argdef(t):
  arg = op_def_pb2.OpDef.ArgDef()
  arg.name = _make_argname_from_tensor_name(t.name)
  arg.type = t.dtype.as_datatype_enum
  return arg


def _get_node_def_attr(op):
  # pylint: disable=protected-access
  return op._node_def.attr
  # pylint: enable=protected-access


def _add_input_array(op, start, limit, dtype, func):
  """Adds a _ListToArray node in the func for op.inputs[start:limit]."""
  node = function_pb2.FunctionDef.Node()
  node.op = "_ListToArray"
  ret_name = op.name + "_L2A_" + str(start)
  node.ret.extend([ret_name])
  node.arg.extend([_make_argname_from_tensor_name(x.name)
                   for x in op.inputs[start:limit]])
  num = limit - start
  node.attr["Tin"].CopyFrom(attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(type=[dtype] * num)))
  node.attr["T"].CopyFrom(attr_value_pb2.AttrValue(type=dtype))
  node.attr["N"].CopyFrom(attr_value_pb2.AttrValue(i=num))
  func.node.extend([node])
  return ret_name


def _add_output_array(op, start, limit, dtype, func):
  """Adds an _ArrayToList node in the func for op.outputs[start:limit]."""
  dtype_proto = attr_value_pb2.AttrValue(type=dtype)
  # A node converting N*T to list(T)
  node = function_pb2.FunctionDef.Node()
  node.op = "_ArrayToList"
  arg_name = op.name + "_A2L_" + str(start)
  ret_name = arg_name + "_out"
  node.ret.append(ret_name)
  node.arg.append(arg_name)
  node.attr["T"].CopyFrom(dtype_proto)
  num = limit - start
  node.attr["N"].CopyFrom(attr_value_pb2.AttrValue(i=num))
  node.attr["out_types"].CopyFrom(attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(type=[dtype] * num)))
  func.node.extend([node])
  # Adds an identity node for each element in the array N*T so that
  # uses of each element can be added easily later. These Identity nodes
  # will be eliminated before graph execution.
  for i in xrange(num):
    node = function_pb2.FunctionDef.Node()
    node.op = "Identity"
    node.arg.append(ret_name + ":" + str(i))
    node.ret.append(_make_argname_from_tensor_name(op.outputs[i].name))
    node.attr["T"].CopyFrom(dtype_proto)
    func.node.extend([node])
  return arg_name


def _add_output_list(op, start, limit, dtype_lst, func):
  """Adds Identity nodes in the func for op.outputs[start:limit]."""
  ret_name = op.name + "_Lst_" + str(start) + "_" + str(limit)
  num = limit - start
  assert len(dtype_lst) == num
  # Adds an identity node for each element in the output list so that
  # uses of each element can be added easily later. These Identity nodes
  # will be eliminated before graph execution.
  for i in xrange(num):
    node = function_pb2.FunctionDef.Node()
    node.op = "Identity"
    node.arg.append(ret_name + ":" + str(i))
    node.ret.append(_make_argname_from_tensor_name(op.outputs[i].name))
    node.attr["T"].CopyFrom(attr_value_pb2.AttrValue(type=dtype_lst[i]))
    func.node.extend([node])
  return ret_name


def _add_op_node(graph, op, func):
  """Converts an op to a function def node and adds it to `func`."""
  node = function_pb2.FunctionDef.Node()
  node.op = op.type
  # pylint: disable=protected-access
  if graph._is_function(op.type):
    op_def = graph._get_function(op.type).signature
  else:
    op_def = op_def_registry.get_registered_ops()[op.type]
  # pylint: enable=protected-access
  attrs = _get_node_def_attr(op)
  out_index = 0
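  # An output arg may describe a single tensor, an N*T array (number_attr), or
  # a list of tensors with per-element types (type_list_attr); each case below
  # maps the op's outputs to function-def return names.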
  for arg_def in op_def.output_arg:
    if arg_def.number_attr:
      dtype = arg_def.type or attrs[arg_def.type_attr].type
      num = attrs[arg_def.number_attr].i
      node.ret.append(_add_output_array(op, out_index, out_index + num, dtype,
                                        func))
      out_index += num
    elif arg_def.type_list_attr:
      dtype_lst = attrs[arg_def.type_list_attr].list.type
      num = len(dtype_lst)
      node.ret.append(_add_output_list(op, out_index, out_index + num,
                                       dtype_lst, func))
      out_index += num
    else:
      node.ret.append(_make_argname_from_tensor_name(op.outputs[
          out_index].name))
      out_index += 1
  inp_index = 0
  for arg_def in op_def.input_arg:
    if arg_def.number_attr:
      dtype = arg_def.type or attrs[arg_def.type_attr].type
      num = attrs[arg_def.number_attr].i
      node.arg.append(_add_input_array(op, inp_index, inp_index + num, dtype,
                                       func))
      inp_index += num
    elif arg_def.type_list_attr:
      num = len(attrs[arg_def.type_list_attr].list.type)
      node.arg.extend([_make_argname_from_tensor_name(op.inputs[i].name)
                       for i in range(inp_index, inp_index + num)])
      inp_index += num
    else:
      node.arg.append(_make_argname_from_tensor_name(op.inputs[inp_index].name))
      inp_index += 1
  node.dep.extend([_make_argname_from_tensor_name(x.name)
                   for x in op.control_inputs])
  for k, v in _get_node_def_attr(op).items():
    node.attr[k].CopyFrom(v)
  func.node.extend([node])


# pylint: disable=line-too-long
def graph_to_function_def(graph, name, inputs, outputs):
  """Returns `graph` as a `FunctionDef` protocol buffer.

  This method creates a [`FunctionDef`](
  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
  protocol buffer that contains all the ops present in the graph.  The
  graph effectively becomes the body of the function.

  The arguments `inputs` and `outputs` will be listed as the input
  and output tensors of the function.  They must be lists of
  tensors present in the graph.  The lists can optionally be empty.

  The returned protocol buffer can be passed to the
  [`Graph.add_function()`](#Graph.add_function) method of a
  different graph to make it available there.

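  For example, a minimal sketch (the graph `g` and the tensors `a`, `b`, and
  `c` below are illustrative):

  ```python
  g = tf.Graph()
  with g.as_default():
    a = tf.placeholder(tf.float32, name="a")
    b = tf.placeholder(tf.float32, name="b")
    c = tf.add(a, b, name="c")
  my_add_def = graph_to_function_def(g, "MyAdd", inputs=[a, b], outputs=[c])
  ```
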
  Args:
    graph: A `Graph`.
    name: string. The name to use for the function.
    inputs: List of tensors. Inputs to the function.
    outputs: List of tensors. Outputs of the function.

  Returns:
    A FunctionDef protocol buffer.
  """
  # pylint: enable=line-too-long
  func = function_pb2.FunctionDef()
  func.signature.name = name
  func.signature.input_arg.extend([_tensor_to_argdef(graph.get_tensor_by_name(
      i.name)) for i in inputs])
  func.signature.output_arg.extend([_tensor_to_argdef(graph.get_tensor_by_name(
      o.name)) for o in outputs])
  func_arg_placeholders = set([i.name for i in inputs])
  g = ops.get_default_graph()
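  # Ops whose first output is one of the function's input tensors (e.g. the
  # placeholders created by define_function) become function arguments and are
  # not added to the function body.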
  for op in graph.get_operations():
    tensor_name = op.values()[0].name
    if tensor_name not in func_arg_placeholders:
      _add_op_node(g, op, func)
  return func


def call_function(func_def, *inputs, **kwargs):
  """Calls the function described by `func_def`.

  This adds a `call` op to the default graph that calls the function described
  by `func_def` with the tensors listed in `inputs` as arguments.  It returns
  the outputs of the call, which are one or more tensors.

  `func_def` is a
  [`FunctionDef`](
  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
  protocol buffer describing a
  TensorFlow function.  See [`define_function()`](#define_function) for an
  easy way to create one from a Python function.

  You can pass an optional keyword parameter `name=string` to name the
  added operation.

  You can pass an optional keyword parameter `noinline=True|False` to instruct
  the runtime not to inline the function body into the call site.

  `func_def` is automatically added to the function library of the graph if
  needed.

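  For example, a minimal sketch (reusing the `my_func` example from
  `define_function()`; names are illustrative):

  ```python
  def my_func(x, y):
    return x + y, x - y

  my_func_def = tf.define_function(my_func, {'x': tf.float32, 'y': tf.float32})
  a = tf.constant([1.0])
  b = tf.constant([2.0])
  c, d = tf.call_function(my_func_def, a, b, name='mycall')
  ```
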
  Args:
    func_def: A `FunctionDef` protocol buffer.
    *inputs: A list of tensors.
    **kwargs: Optional keyword arguments.  Can only contain 'name' and
      'noinline'.

  Returns:
    A single tensor if the function has one output, a tuple of tensors if it
    has several outputs, or the created op if it has no outputs.

  Raises:
    ValueError: if the arguments are invalid.
  """
  name = kwargs.pop("name", None)
  noinline = kwargs.pop("noinline", None)
  if noinline is None:
    attrs = None
  else:
    attrs = {}
    attrs["noinline"] = attr_value_pb2.AttrValue(b=bool(noinline))
  if kwargs:
    raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
  func_name = func_def.signature.name
  with ops.op_scope(inputs, name, func_name) as name:
    if len(inputs) != len(func_def.signature.input_arg):
      raise ValueError("Expected number of arguments: %d, received: %d" %
                       (len(func_def.signature.input_arg), len(inputs)))
    output_types = [dtypes.DType(x.type) for x in func_def.signature.output_arg]
    # TODO(touts): Pass compute_shapes as "try if function exists"
    g = ops.get_default_graph()
    op = g.create_op(func_name,
                     list(inputs),
                     output_types,
                     name=name,
                     attrs=attrs,
                     compute_shapes=False)
    if op.outputs:
      if len(op.outputs) == 1:
        return op.outputs[0]
      else:
        return tuple(op.outputs)
    else:
      return op


def _get_func_name(func):
  if isinstance(func, _DefinedFunction):
    return func.name
  elif callable(func):
    if inspect.isfunction(func):
      return func.__name__
    elif inspect.ismethod(func):
      return "%s.%s" % (func.__self__.__name__, func.__name__)
    else:  # Probably a class instance with __call__
      return type(func).__name__
  else:
    raise ValueError("Argument must be callable")


def define_function(func, input_types, func_name=None, grad_func=None,
                    python_grad_func=None):
  """Creates a `FunctionDef` for a Python function.

  `func` is a Python function that receives zero or more tensors and returns at
  least one tensor.  It should add ops to the default graph the usual way by
  calling TensorFlow functions such as `tf.constant()`, `tf.matmul()`, etc.

  `input_types` is a dictionary of strings to `tf.DType` objects.  Keys are
  the names of arguments to `func`.  The values indicate the type of tensor
  expected by the function.

  The returned `FunctionDef` protocol buffer is also added to the
  default graph library.  After it has been added you can add calls to
  the function by passing it to `tf.call_function()`, together with a
  list of tensors to use as inputs for the function.

  Notes:

  *  `func` is called once, with `placeholder` tensors of the types specified in
     `input_types` as arguments.
  *  Values returned by `func` must be tensors and they are recorded as being
     the output of the function def.
  *  While `func` is called, an empty graph is temporarily pushed as the
     default graph.  All ops added by `func` to that graph are part of the body
     of the returned function def.

  For example (see also the [How To on functions](link_needed)):

  ```python
  # A function that receives two tensors x, y and returns their
  # sum and difference.
  def my_func(x, y):
    return x + y, x - y

  # Create a FunctionDef for 'my_func'. (This does not change the default
  # graph.)
  my_func_def = tf.define_function(my_func, {'x': tf.float32, 'y': tf.float32})
  # Alternatively:
  # my_func_def = tf.define_function(my_func, [tf.float32, tf.float32])

  # Build the graph, calling the function.
  a = tf.constant([1.0])
  b = tf.constant([2.0])
  c, d = tf.call_function(my_func_def, a, b, name='mycall')
  ```

  Args:
    func: a Python function.
    input_types: if a dict, keys are the names of the arguments of
      `func`, values are their expected `tf.DType`. Otherwise,
      a list of `tf.DType`s.
    func_name: Python string.  If not None, specifies the name to use when
      creating the Function.  By default, introspection on `func` is used to
      generate a name.
    grad_func: If not None, specifies the gradient function. The
               gradient function must satisfy the criterion defined in
               function.proto:GradientDef.
    python_grad_func: If not None, specifies the gradient function with the same
               interface as that expected by `tf.RegisterGradient`. This
               will be called by tf.gradients to add the gradient ops to the
               graph. No more than one of {grad_func, python_grad_func} may be
               specified.

  Returns:
    A FunctionDef protocol buffer.

  Raises:
    ValueError: if the arguments are invalid.

  """
  # TODO(touts): Lift the limitation that func can only receive Tensor args.
  func_name = func_name or _get_func_name(func)
  grad_func_name = _get_func_name(grad_func) if grad_func is not None else None

  argspec = inspect.getargspec(func)
  if argspec.keywords or argspec.defaults:
    raise ValueError("Functions with argument defaults or keyword "
                     "arguments are not supported.")
  if inspect.isfunction(func):
    if argspec.varargs and (
        len(argspec.args) > len(input_types)) or not argspec.varargs and (
            len(argspec.args) != len(input_types)):
      raise ValueError("The function has fewer arguments "
                       "than the number of specified input types.")
    argnames = argspec.args
  elif inspect.ismethod(func):
    if argspec.varargs and (
        len(argspec.args) > 1 + len(input_types)) or not argspec.varargs and (
            len(argspec.args) != 1 + len(input_types)):
      raise ValueError("The class function has fewer arguments "
                       "than the number of specified input types.")
    # 1st argument is the "class" type.
    argnames = argspec.args[1:]

  args = []
  if isinstance(input_types, (list, tuple)):
    for i in range(len(input_types)):
      argname = argnames[i] if i < len(argnames) else ("arg%d" % i)
      argtype = input_types[i]
      args.append((argname, argtype))
  else:
    for name in argnames:
      if name not in input_types:
        raise ValueError("Missing type for argument: " + name)
      args.append((name, input_types[name]))

  # Create the func_def object.
  temp_graph = ops.Graph()
  with temp_graph.as_default():
    # List of placeholders for the function_def.
    inputs = []
    # Arglist to call 'func'
    kwargs = {}
    for (argname, argtype) in args:
      argholder = array_ops.placeholder(argtype, name=argname)
      inputs.append(argholder)
      kwargs[argname] = argholder
    # Call func and gather the output tensors.
    if isinstance(input_types, (list, tuple)):
      outputs = func(*inputs)
    else:
      outputs = func(**kwargs)
    if not isinstance(outputs, ops.Tensor) and not outputs:
      raise ValueError("Function must return at least one tensor")
    # Convenience: if func only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)
  # Build the FunctionDef
  func_def = graph_to_function_def(temp_graph, func_name, inputs, outputs)
  g = ops.get_default_graph()
  # pylint: disable=protected-access
  g._add_function(func_def, grad_func_name, python_grad_func=python_grad_func)
  # pylint: enable=protected-access
  return func_def


class Defun(object):
  """Decorator used to define TensorFlow functions.

  Use this decorator to make a Python function usable directly as a TensorFlow
  function.

  The decorated function must add ops to the default graph and return zero or
  more `Tensor` objects.  Call the decorator with the expected types of the
  arguments of the function to decorate, either positionally or as named
  arguments.

  For example, if the function to decorate accepts two `tf.float32` arguments
  named `x` and `y`, call the decorator with:

      @Defun(tf.float32, tf.float32)
      def foo(x, y):
        ...

  When you call the decorated function, it will add `call` ops to the graph.

  For example (see also the [How To on functions](link_needed)):

  ```python
  # Defining the function.
  @tf.Defun(tf.float32, tf.float32)
  def MyFunc(x, y):
    return x + y, x - y

  # Building the graph.
  a = tf.constant([1.0])
  b = tf.constant([2.0])
  c, d = MyFunc(a, b, name='mycall')
  ```

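  The argument types can also be given by name, matching the arguments of the
  decorated function (a sketch equivalent to the example above):

  ```python
  @tf.Defun(x=tf.float32, y=tf.float32)
  def MyFunc2(x, y):
    return x + y, x - y
  ```
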
  @@__init__
  """

  def __init__(self, *input_type_list, **input_types):
    """Create a `Defun` decorator.

    Args:
      *input_type_list: A list of `tf.DType`s.
      **input_types: Dict mapping argument names (strings) to `tf.DType`s,
        one key for each argument of the function to decorate.

       Note that these optional keyword arguments are also accepted:
         func_name - (optional).  A Python string, the name to use to declare
           this `Function` in the graph.

         grad_func - (optional).  A function implementing the gradient of the
           function-to-register.  This is usually a previously
           `Defun`-registered Python callable.

         python_grad_func - (optional).  A function implementing the gradient of
           the function python-side. This function must take the current op and
           the gradients w.r.t. its outputs, and return the gradients w.r.t. the
           inputs (identical to the interface expected by
           `tf.RegisterGradient`).
    """
    self._func_name = input_types.pop("func_name", None)
    self._grad_func = input_types.pop("grad_func", None)
    self._python_grad_func = input_types.pop("python_grad_func", None)
    assert not input_type_list or not input_types, (
        "Can't specify both *input_type_list and **input_types")
    self._input_types = input_types
    self._input_type_list = input_type_list

  def __call__(self, f):
    if self._input_types:
      func_def = define_function(
          f, self._input_types,
          func_name=self._func_name, grad_func=self._grad_func,
          python_grad_func=self._python_grad_func)
    else:
      func_def = define_function(
          f, self._input_type_list,
          func_name=self._func_name, grad_func=self._grad_func,
          python_grad_func=self._python_grad_func)

    return _DefinedFunction(definition=func_def)


class _DefinedFunction(object):
  """Class to store the name and definition of the function defined by Defun.

  This object implements a callable interface that runs `call_function`, and
  provides a `name` property to look up the name of the `Function`.

  An instance of `_DefinedFunction` may be passed to the `grad_func` parameter
  of `define_function` and `Defun`.
  """

  def __init__(self, definition):
    self._definition = definition

  @property
  def name(self):
    return self._definition.signature.name

  def __call__(self, *args, **kwargs):
    return call_function(self._definition, *args, **kwargs)