function.py revision 59f1eba5fb94506a205fa2e81145667754739da5
1# Copyright 2015 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Python front-end supports for functions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
import functools
import inspect
import re
23
24from six.moves import xrange  # pylint: disable=redefined-builtin
25
26from tensorflow.core.framework import attr_value_pb2
27from tensorflow.core.framework import function_pb2
28from tensorflow.core.framework import op_def_pb2
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import op_def_registry
31from tensorflow.python.framework import ops
32from tensorflow.python.ops import array_ops
33
34
35def _make_argname_from_tensor_name(name):
36  return re.sub(":0$", "", name).replace(":", "_o")
37
38
def _tensor_to_argdef(t):
  """Builds an `OpDef.ArgDef` describing tensor `t` as a function argument.

  The argument name is derived from `t.name` via
  `_make_argname_from_tensor_name`; its type is `t`'s DataType enum value.
  """
  return op_def_pb2.OpDef.ArgDef(
      name=_make_argname_from_tensor_name(t.name),
      type=t.dtype.as_datatype_enum)
43  return arg
44
45
46def _get_node_def_attr(op):
47  # pylint: disable=protected-access
48  return op._node_def.attr
49  # pylint: enable=protected-access
50
51
def _add_input_array(op, start, limit, dtype, func):
  """Adds a _ListToArray node to `func` packing op.inputs[start:limit].

  Args:
    op: The Operation whose inputs are being packed.
    start: Index of the first input to pack (inclusive).
    limit: Index one past the last input to pack (exclusive).
    dtype: DataType enum value shared by every packed input.
    func: The FunctionDef to append the node to; mutated in place.

  Returns:
    The name of the new node's result, to be used as the op's argument.
  """
  count = limit - start
  packed_name = "%s_L2A_%s" % (op.name, start)

  pack = function_pb2.FunctionDef.Node()
  pack.op = "_ListToArray"
  pack.ret.append(packed_name)
  pack.arg.extend([_make_argname_from_tensor_name(t.name)
                   for t in op.inputs[start:limit]])

  elem_types = attr_value_pb2.AttrValue.ListValue(type=[dtype] * count)
  pack.attr["Tin"].CopyFrom(attr_value_pb2.AttrValue(list=elem_types))
  pack.attr["T"].CopyFrom(attr_value_pb2.AttrValue(type=dtype))
  pack.attr["N"].CopyFrom(attr_value_pb2.AttrValue(i=count))

  func.node.extend([pack])
  return packed_name
66  return ret_name
67
68
def _add_output_array(op, start, limit, dtype, func):
  """Adds an _ArrayToList node (plus Identity nodes) for op.outputs[start:limit].

  The _ArrayToList node converts the N*T array named `<op>_A2L_<start>` into a
  list(T) named `<op>_A2L_<start>_out`. One Identity node is then added per
  element, so that each output op.outputs[start + i] gets its own named result
  that uses can refer to. These Identity nodes are meant to be eliminated
  before graph execution.

  Args:
    op: The Operation whose outputs are being converted.
    start: Index of the first output in the array (inclusive).
    limit: Index one past the last output in the array (exclusive).
    dtype: DataType enum value of every element.
    func: The FunctionDef to append the new nodes to; mutated in place.

  Returns:
    The name of the _ArrayToList node's input, used by the caller as the op
    node's return name for this output argument.
  """
  dtype_proto = attr_value_pb2.AttrValue(type=dtype)
  num = limit - start
  arg_name = op.name + "_A2L_" + str(start)
  ret_name = arg_name + "_out"

  # A node converting N*T to list(T).
  node = function_pb2.FunctionDef.Node()
  node.op = "_ArrayToList"
  node.ret.append(ret_name)
  node.arg.append(arg_name)
  node.attr["T"].CopyFrom(dtype_proto)
  node.attr["N"].CopyFrom(attr_value_pb2.AttrValue(i=num))
  node.attr["out_types"].CopyFrom(attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(type=[dtype] * num)))
  func.node.extend([node])

  # One Identity per element. Element i of the list is the op's output number
  # start + i; indexing by i alone would mis-name every array that does not
  # begin at output 0.
  for i in xrange(num):
    identity = function_pb2.FunctionDef.Node()
    identity.op = "Identity"
    identity.arg.append(ret_name + ":" + str(i))
    identity.ret.append(
        _make_argname_from_tensor_name(op.outputs[start + i].name))
    identity.attr["T"].CopyFrom(dtype_proto)
    func.node.extend([identity])
  return arg_name
96  return arg_name
97
98
def _add_output_list(op, start, limit, dtype_lst, func):
  """Adds Identity nodes in `func` for the list outputs op.outputs[start:limit].

  Used for output arguments whose element types come from a type-list attr.
  No conversion node is added. Instead, one Identity node per element reads
  `<ret_name>:i` and is named after op.outputs[start + i], so that uses of
  each element can be added easily later. These Identity nodes are meant to
  be eliminated before graph execution.

  Args:
    op: The Operation whose outputs are being exposed.
    start: Index of the first output in the list (inclusive).
    limit: Index one past the last output in the list (exclusive).
    dtype_lst: Sequence of DataType enum values, one per element; its length
      must equal `limit - start`.
    func: The FunctionDef to append the new nodes to; mutated in place.

  Returns:
    The name the op node uses for this list output; the Identity nodes read
    its elements.
  """
  ret_name = op.name + "_Lst_" + str(start) + "_" + str(limit)
  num = limit - start
  assert len(dtype_lst) == num, (
      "expected %d dtypes, got %d" % (num, len(dtype_lst)))
  for i in xrange(num):
    node = function_pb2.FunctionDef.Node()
    node.op = "Identity"
    node.arg.append(ret_name + ":" + str(i))
    # Element i of this list is the op's output number start + i.
    node.ret.append(_make_argname_from_tensor_name(op.outputs[start + i].name))
    node.attr["T"].CopyFrom(attr_value_pb2.AttrValue(type=dtype_lst[i]))
    func.node.extend([node])
  return ret_name
114  return ret_name
115
116
def _add_op_node(graph, op, func):
  """Converts `op` into a FunctionDef.Node and appends it to `func`.

  Some helper nodes adapt list-valued arguments. These are appended to `func`
  before the node for `op` itself: first the output helpers, in output-arg
  order, then the input helpers, in input-arg order.

  Args:
    graph: Graph consulted through its private function-library accessors to
      resolve `op.type` when it names a function rather than a registered op.
    op: The Operation to convert.
    func: The FunctionDef being populated; mutated in place.
  """
  # pylint: disable=protected-access
  if graph._is_function(op.type):
    op_def = graph._get_function(op.type).signature
  else:
    op_def = op_def_registry.get_registered_ops()[op.type]
  # pylint: enable=protected-access
  attrs = _get_node_def_attr(op)

  node = function_pb2.FunctionDef.Node()
  node.op = op.type

  # Outputs. Each output arg consumes `count` consecutive entries of
  # op.outputs; arrays and lists are exposed through helper nodes.
  pos = 0
  for out_arg in op_def.output_arg:
    if out_arg.number_attr:
      # `or` short-circuits so the type attr is only read when needed.
      elem_type = out_arg.type or attrs[out_arg.type_attr].type
      count = attrs[out_arg.number_attr].i
      ret = _add_output_array(op, pos, pos + count, elem_type, func)
    elif out_arg.type_list_attr:
      elem_types = attrs[out_arg.type_list_attr].list.type
      count = len(elem_types)
      ret = _add_output_list(op, pos, pos + count, elem_types, func)
    else:
      count = 1
      ret = _make_argname_from_tensor_name(op.outputs[pos].name)
    node.ret.append(ret)
    pos += count

  # Inputs. Each input arg consumes `count` consecutive entries of op.inputs.
  pos = 0
  for in_arg in op_def.input_arg:
    if in_arg.number_attr:
      elem_type = in_arg.type or attrs[in_arg.type_attr].type
      count = attrs[in_arg.number_attr].i
      node.arg.append(_add_input_array(op, pos, pos + count, elem_type, func))
    elif in_arg.type_list_attr:
      count = len(attrs[in_arg.type_list_attr].list.type)
      node.arg.extend([_make_argname_from_tensor_name(op.inputs[j].name)
                       for j in range(pos, pos + count)])
    else:
      count = 1
      node.arg.append(_make_argname_from_tensor_name(op.inputs[pos].name))
    pos += count

  node.dep.extend([_make_argname_from_tensor_name(ctrl.name)
                   for ctrl in op.control_inputs])
  for key, value in attrs.items():
    node.attr[key].CopyFrom(value)
  func.node.extend([node])
166  func.node.extend([node])
167
168
# pylint: disable=line-too-long
def graph_to_function_def(graph, name, inputs, outputs):
  """Returns `graph` as a `FunctionDef` protocol buffer.

  This method creates a [`FunctionDef`](
  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
  protocol buffer that contains all the ops present in the graph.  The
  graph effectively becomes the body of the function.

  The arguments `inputs` and `outputs` will be listed as the inputs
  and outputs tensors of the function.  They must be lists of
  tensors present in the graph.  The lists can optionally be empty.

  The returned protocol buffer can be passed to the
  [`Graph.add_function()`](#Graph.add_function) method of a
  different graph to make it available there.

  Args:
    graph: The `Graph` whose operations form the body of the function.
    name: string. The name to use for the function.
    inputs: List of tensors. Inputs to the function.
    outputs: List of tensors. Outputs of the function.

  Returns:
    A FunctionDef protocol buffer.
  """
  # pylint: enable=line-too-long
  func = function_pb2.FunctionDef()
  func.signature.name = name
  func.signature.input_arg.extend([_tensor_to_argdef(graph.get_tensor_by_name(
      i.name)) for i in inputs])
  func.signature.output_arg.extend([_tensor_to_argdef(graph.get_tensor_by_name(
      o.name)) for o in outputs])
  # Ops that produce the function's arguments are not part of the body. They
  # are recognized by the name of their first output tensor.
  func_arg_placeholders = {i.name for i in inputs}
  # NOTE(review): function-valued op types are resolved against the current
  # default graph's library rather than `graph`, presumably because
  # define_function builds the body in a temporary graph and registers
  # functions on the default graph. Confirm before changing.
  g = ops.get_default_graph()
  for op in graph.get_operations():
    op_outputs = op.values()
    # Ops with no outputs can never be argument placeholders; indexing
    # op.values()[0] unconditionally would raise IndexError for them.
    if op_outputs and op_outputs[0].name in func_arg_placeholders:
      continue
    _add_op_node(g, op, func)
  return func
208  return func
209
210
def call_function(func_def, *inputs, **kwargs):
  """Calls the function described by `func_def`.

  This adds a `call` op to the default graph that calls the function described
  by `func_def` with the tensors listed in `inputs` as arguments.  It returns
  the outputs of the call.

  `func_def` is a
  [`FunctionDef`](
  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
  protocol buffer describing a
  TensorFlow function.  See [`define_function()`](#define_function) for an
  easy way to create one from a Python function.

  You can pass an optional keyword parameter `name=string` to name the
  added operation.

  `func_def` is automatically added to the function library of the default
  graph if no function with the same name is registered there yet.

  Args:
    func_def: A `FunctionDef` protocol buffer.
    *inputs: A list of tensors
    **kwargs: Optional keyword arguments.  Can only contain 'name'.

  Returns:
    A single `Tensor` if the function has exactly one output, a tuple of
    `Tensor`s if it has several, or the created call `Operation` itself if it
    has no outputs.

  Raises:
    ValueError: if the arguments are invalid.
  """
  name = kwargs.pop("name", None)
  if kwargs:
    raise ValueError("Unknown keyword arguments: %s" % sorted(kwargs))
  func_name = func_def.signature.name
  with ops.op_scope(inputs, name, func_name) as name:
    if len(inputs) != len(func_def.signature.input_arg):
      raise ValueError("Expected number of arguments: %d, got: %d" %
                       (len(func_def.signature.input_arg), len(inputs)))
    output_types = [dtypes.DType(x.type) for x in func_def.signature.output_arg]
    g = ops.get_default_graph()
    # Register the callee on first use, as documented above, so that the graph
    # holding the call op can also resolve the function body. This covers the
    # case where func_def was created while another graph was the default.
    # pylint: disable=protected-access
    if not g._is_function(func_name):
      g._add_function(func_def)
    # pylint: enable=protected-access
    # TODO(touts): Pass compute_shapes as "try if function exists"
    op = g.create_op(func_name,
                     list(inputs),
                     output_types,
                     name=name,
                     compute_shapes=False)
    if not op.outputs:
      return op
    if len(op.outputs) == 1:
      return op.outputs[0]
    return tuple(op.outputs)
264      return op
265
266
def define_function(func, input_types):
  """Creates a `FunctionDef` for a python function.

  `func` is a Python function that receives zero or more tensors and returns at
  least one tensor.  It should add ops to the default graph the usual way by
  calling TensorFlow functions such as `tf.constant()`, `tf.matmul()`, etc.

  `input_types` is a dictionary of strings to `tf.DType` objects.  Keys are
  the names of arguments to `func`.  The values indicate the type of tensor
  expected for each argument.

  The returned `FunctionDef` protocol buffer is also added to the
  default graph library.  After it has been added you can add calls to
  the function by passing it to `tf.call_function()`, together with a
  list of tensors to use as inputs for the function.

  Notes:

  *  `func` is called once, with `placeholder` tensors of the types specified in
     `input_types` as arguments.
  *  Values returned by `func` must be tensors and they are recorded as being
     the output of the function def.  A single returned tensor is treated as
     a one-element tuple.
  *  While `func` is being called, an empty graph is temporarily pushed as the
     default graph.  All ops added by `func` to that graph are part of the body
     of the returned function def.
  *  `func` may also be a bound method; its first (`self`/`cls`) argument is
     not given a placeholder.

  Example, but also see the [How To on functions](link_needed).

  ```python
  # A function that receives two tensors x, y and returns their
  # sum and difference.
  def my_func(x, y):
    return x + y, x - y

  # Create a FunctionDef for 'my_func'.  This also adds it to the default
  # graph's function library.
  my_func_def = tf.define_function(my_func, {'x': tf.float32, 'y': tf.float32})

  # Build the graph, calling the function.
  a = tf.constant([1.0])
  b = tf.constant([2.0])
  c, d = tf.call_function(my_func_def, a, b, name='mycall')
  ```

  Args:
    func: a Python function or bound method.
    input_types: dict.  Keys are the names of the arguments of `func`, values
      are their expected `tf.DType`.

  Returns:
    A FunctionDef protocol buffer.

  Raises:
    ValueError: if the arguments are invalid.

  """
  # TODO(touts): Lift the limitation that func can only receive Tensor args.
  if inspect.isfunction(func):
    func_name = func.__name__
  elif inspect.ismethod(func):
    owner = func.__self__
    # A classmethod is bound to the class itself; an ordinary method is bound
    # to an instance, which has no __name__ of its own.
    owner_cls = owner if inspect.isclass(owner) else type(owner)
    func_name = owner_cls.__name__ + "." + func.__name__
  else:
    raise ValueError("Argument must be a function")

  # inspect.getargspec is deprecated on Python 3 and removed in 3.11; use
  # getfullargspec where available. Keyword-only args are rejected, matching
  # the ValueError getargspec itself raises for such functions.
  if hasattr(inspect, "getfullargspec"):
    argspec = inspect.getfullargspec(func)
    has_extra_args = bool(argspec.varargs or argspec.varkw or
                          argspec.defaults or argspec.kwonlyargs)
  else:
    argspec = inspect.getargspec(func)  # pylint: disable=deprecated-method
    has_extra_args = bool(argspec.varargs or argspec.keywords or
                          argspec.defaults)
  if has_extra_args:
    raise ValueError("Only functions with plain arglists are supported.")

  if inspect.isfunction(func):
    if len(argspec.args) != len(input_types):
      raise ValueError("The function must have the same number of arguments "
                       "as the number of specified input types.")
    args = argspec.args
  else:
    if len(argspec.args) != 1 + len(input_types):
      raise ValueError(
          "The class function must have the same number of arguments "
          "as the number of specified input types.")
    args = argspec.args[1:]  # 1st argument is the "class" type.

  # Create the func_def object.
  temp_graph = ops.Graph()
  with temp_graph.as_default():
    # List of placeholders for the function_def.
    inputs = []
    # Arglist to call 'func'
    kwargs = {}
    for argname in args:
      if argname not in input_types:
        raise ValueError("Missing type for argument: " + argname)
      argholder = array_ops.placeholder(input_types[argname], name=argname)
      inputs.append(argholder)
      kwargs[argname] = argholder
    # Call func and gather the output tensors.
    outputs = func(**kwargs)
    # Normalize to a sequence *before* testing for emptiness, so a single
    # returned Tensor is never evaluated for truthiness.
    # NOTE(review): some Tensor versions raise TypeError from bool().
    if outputs is None:
      outputs = ()
    elif not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)
    if not outputs:
      raise ValueError("Function must return at least one tensor")
  # Build the FunctionDef
  func_def = graph_to_function_def(temp_graph, func_name, inputs, outputs)
  g = ops.get_default_graph()
  g._add_function(func_def)  # pylint: disable=protected-access
  return func_def
369  return func_def
370
371
class Defun(object):
  """Decorator used to define TensorFlow functions.

  Use this decorator to make a Python function usable directly as a TensorFlow
  function.

  The decorated function must add ops to the default graph and return one or
  more `Tensor` objects.  Call the decorator with named arguments, one for each
  argument of the function to decorate, with the expected type of the argument
  as value.

  For example if the function to decorate accepts two `tf.float32` arguments
  named `x` and `y`, call the decorator with:

      @Defun(x=tf.float32, y=tf.float32)
      def foo(x, y):
        ...

  When you call the decorated function it will add `call` ops to the graph.

  Example, but also see the [How To on functions](link_needed).

  ```python
  # Defining the function.
  @tf.Defun(x=tf.float32, y=tf.float32)
  def MyFunc(x, y):
    return x + y, x - y

  # Building the graph.
  a = tf.constant([1.0])
  b = tf.constant([2.0])
  c, d = MyFunc(a, b, name='mycall')
  ```

  @@__init__
  """

  def __init__(self, **input_types):
    """Create a `Defun` decorator.

    Args:
      **input_types: Dict mapping string with `tf.DType`
        One key for each argument of the function to decorate.
    """
    self._input_types = input_types

  def __call__(self, f):
    """Defines `f` as a TensorFlow function and returns a caller for it.

    The function definition is built once, here, via `define_function`.

    Args:
      f: The Python function to decorate.

    Returns:
      A callable that takes the input tensors (plus an optional `name=`
      keyword) and adds a call op through `call_function`, returning its
      outputs. It carries `f`'s name and docstring.
    """
    func_def = define_function(f, self._input_types)

    @functools.wraps(f)
    def _call(*args, **kwargs):
      return call_function(func_def, *args, **kwargs)

    return _call
421