# function.py revision 59f1eba5fb94506a205fa2e81145667754739da5
1# Copyright 2015 Google Inc. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Python front-end supports for functions.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import inspect 22import re 23 24from six.moves import xrange # pylint: disable=redefined-builtin 25 26from tensorflow.core.framework import attr_value_pb2 27from tensorflow.core.framework import function_pb2 28from tensorflow.core.framework import op_def_pb2 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import op_def_registry 31from tensorflow.python.framework import ops 32from tensorflow.python.ops import array_ops 33 34 35def _make_argname_from_tensor_name(name): 36 return re.sub(":0$", "", name).replace(":", "_o") 37 38 39def _tensor_to_argdef(t): 40 arg = op_def_pb2.OpDef.ArgDef() 41 arg.name = _make_argname_from_tensor_name(t.name) 42 arg.type = t.dtype.as_datatype_enum 43 return arg 44 45 46def _get_node_def_attr(op): 47 # pylint: disable=protected-access 48 return op._node_def.attr 49 # pylint: enable=protected-access 50 51 52def _add_input_array(op, start, limit, dtype, func): 53 """Adds a _ListToArray node in the func for op.inputs[start:limit].""" 54 node = function_pb2.FunctionDef.Node() 55 node.op = "_ListToArray" 56 ret_name = op.name + "_L2A_" + str(start) 57 
node.ret.extend([ret_name]) 58 node.arg.extend([_make_argname_from_tensor_name(x.name) 59 for x in op.inputs[start:limit]]) 60 num = limit - start 61 node.attr["Tin"].CopyFrom(attr_value_pb2.AttrValue( 62 list=attr_value_pb2.AttrValue.ListValue(type=[dtype] * num))) 63 node.attr["T"].CopyFrom(attr_value_pb2.AttrValue(type=dtype)) 64 node.attr["N"].CopyFrom(attr_value_pb2.AttrValue(i=num)) 65 func.node.extend([node]) 66 return ret_name 67 68 69def _add_output_array(op, start, limit, dtype, func): 70 """Adds a _ArrayToList node in the func for op.outputs[start:limit].""" 71 dtype_proto = attr_value_pb2.AttrValue(type=dtype) 72 # A node converting N*T to list(T) 73 node = function_pb2.FunctionDef.Node() 74 node.op = "_ArrayToList" 75 arg_name = op.name + "_A2L_" + str(start) 76 ret_name = arg_name + "_out" 77 node.ret.append(ret_name) 78 node.arg.append(arg_name) 79 node.attr["T"].CopyFrom(dtype_proto) 80 num = limit - start 81 node.attr["N"].CopyFrom(attr_value_pb2.AttrValue(i=num)) 82 node.attr["out_types"].CopyFrom(attr_value_pb2.AttrValue( 83 list=attr_value_pb2.AttrValue.ListValue(type=[dtype] * num))) 84 func.node.extend([node]) 85 num = limit - start 86 # Adds an identity node for each element in the array N*T so that 87 # uses of each element can be added easily later. These Identity 88 # will be eliminated before graph execution. 
89 for i in xrange(num): 90 node = function_pb2.FunctionDef.Node() 91 node.op = "Identity" 92 node.arg.append(ret_name + ":" + str(i)) 93 node.ret.append(_make_argname_from_tensor_name(op.outputs[i].name)) 94 node.attr["T"].CopyFrom(dtype_proto) 95 func.node.extend([node]) 96 return arg_name 97 98 99def _add_output_list(op, start, limit, dtype_lst, func): 100 """Adds a _ArrayToList node in the func for op.outputs[start:limit].""" 101 ret_name = op.name + "_Lst_" + str(start) + "_" + str(limit) 102 num = limit - start 103 assert len(dtype_lst) == num 104 # Adds an identity node for each element in the array N*T so that 105 # uses of each element can be added easily later. These Identity 106 # will be eliminated before graph execution. 107 for i in xrange(num): 108 node = function_pb2.FunctionDef.Node() 109 node.op = "Identity" 110 node.arg.append(ret_name + ":" + str(i)) 111 node.ret.append(_make_argname_from_tensor_name(op.outputs[i].name)) 112 node.attr["T"].CopyFrom(attr_value_pb2.AttrValue(type=dtype_lst[i])) 113 func.node.extend([node]) 114 return ret_name 115 116 117def _add_op_node(graph, op, func): 118 """Converts an op to a function def node and add it to `func`.""" 119 node = function_pb2.FunctionDef.Node() 120 node.op = op.type 121 # pylint: disable=protected-access 122 if graph._is_function(op.type): 123 op_def = graph._get_function(op.type).signature 124 else: 125 op_def = op_def_registry.get_registered_ops()[op.type] 126 # pylint: enable=protected-access 127 attrs = _get_node_def_attr(op) 128 out_index = 0 129 for arg_def in op_def.output_arg: 130 if arg_def.number_attr: 131 dtype = arg_def.type or attrs[arg_def.type_attr].type 132 num = attrs[arg_def.number_attr].i 133 node.ret.append(_add_output_array(op, out_index, out_index + num, dtype, 134 func)) 135 out_index += num 136 elif arg_def.type_list_attr: 137 dtype_lst = attrs[arg_def.type_list_attr].list.type 138 num = len(dtype_lst) 139 node.ret.append(_add_output_list(op, out_index, out_index + num, 
140 dtype_lst, func)) 141 out_index += num 142 else: 143 node.ret.append(_make_argname_from_tensor_name(op.outputs[ 144 out_index].name)) 145 out_index += 1 146 inp_index = 0 147 for arg_def in op_def.input_arg: 148 if arg_def.number_attr: 149 dtype = arg_def.type or attrs[arg_def.type_attr].type 150 num = attrs[arg_def.number_attr].i 151 node.arg.append(_add_input_array(op, inp_index, inp_index + num, dtype, 152 func)) 153 inp_index += num 154 elif arg_def.type_list_attr: 155 num = len(attrs[arg_def.type_list_attr].list.type) 156 node.arg.extend([_make_argname_from_tensor_name(op.inputs[i].name) 157 for i in range(inp_index, inp_index + num)]) 158 inp_index += num 159 else: 160 node.arg.append(_make_argname_from_tensor_name(op.inputs[inp_index].name)) 161 inp_index += 1 162 node.dep.extend([_make_argname_from_tensor_name(x.name) 163 for x in op.control_inputs]) 164 for k, v in _get_node_def_attr(op).items(): 165 node.attr[k].CopyFrom(v) 166 func.node.extend([node]) 167 168 169# pylint: disable=line-too-long 170def graph_to_function_def(graph, name, inputs, outputs): 171 """Returns `graph` as a `FunctionDef` protocol buffer. 172 173 This method creates a [`FunctionDef`]( 174 https://www.tensorflow.org/code/tensorflow/core/framework/function.proto) 175 protocol buffer that contains all the ops present in the graph. The 176 graph effectively becomes the body of the function. 177 178 The arguments `inputs` and `outputs` will be listed as the inputs 179 and outputs tensors of the function. They must be lists of 180 tensors present in the graph. The lists can optionally be empty. 181 182 The returned protocol buffer can be passed to the 183 [`Graph.add_function()`](#Graph.add_function) method of a 184 different graph to make it available there. 185 186 Args: 187 graph: GraphDef proto. 188 name: string. The name to use for the function. 189 inputs: List of tensors. Inputs to the function. 190 outputs: List of tensors. Outputs of the function. 
191 192 Returns: 193 A FunctionDef protocol buffer. 194 """ 195 # pylint: enable=line-too-long 196 func = function_pb2.FunctionDef() 197 func.signature.name = name 198 func.signature.input_arg.extend([_tensor_to_argdef(graph.get_tensor_by_name( 199 i.name)) for i in inputs]) 200 func.signature.output_arg.extend([_tensor_to_argdef(graph.get_tensor_by_name( 201 o.name)) for o in outputs]) 202 func_arg_placeholders = set([i.name for i in inputs]) 203 g = ops.get_default_graph() 204 for op in graph.get_operations(): 205 tensor_name = op.values()[0].name 206 if tensor_name not in func_arg_placeholders: 207 _add_op_node(g, op, func) 208 return func 209 210 211def call_function(func_def, *inputs, **kwargs): 212 """Calls the function described by `func_def`. 213 214 This adds a `call` op to the default graph that calls the function described 215 by `func_def` with the tensors listed in `inputs` as arguments. It returns 216 the outputs of the call, which are one or more tensors. 217 218 `func_def` is a 219 [`FunctionDef`]( 220 https://www.tensorflow.org/code/tensorflow/core/framework/function.proto) 221 protcol buffer describing a 222 TensorFlow function. See [`define_function()`](#define_function) for an 223 easy way to create one from a Python function. 224 225 You can pass an optional keyword parameters `name=string` to name the 226 added operation. 227 228 `func_def` is automatically added to the function library of the graph if 229 needed. 230 231 Args: 232 func_def: A `FunctionDef` protocol buffer. 233 *inputs: A list of tensors 234 **kwargs: Optional keyword arguments. Can only contain 'name'. 235 236 Returns: 237 A list of tensors representing the outputs of the call to `func_def`. 238 239 Raises: 240 ValueError: if the arguments are invalid. 
241 """ 242 name = kwargs.pop("name", None) 243 if kwargs: 244 raise ValueError("Unknown keyword arguments: %s" % kwargs.keys()) 245 func_name = func_def.signature.name 246 with ops.op_scope(inputs, name, func_name) as name: 247 if len(inputs) != len(func_def.signature.input_arg): 248 raise ValueError("Expected number of arguments: %d" % 249 len(func_def.signature.input_arg)) 250 output_types = [dtypes.DType(x.type) for x in func_def.signature.output_arg] 251 # TODO(touts): Pass compute_shapes as "try if function exists" 252 g = ops.get_default_graph() 253 op = g.create_op(func_name, 254 list(inputs), 255 output_types, 256 name=name, 257 compute_shapes=False) 258 if op.outputs: 259 if len(op.outputs) == 1: 260 return op.outputs[0] 261 else: 262 return tuple(op.outputs) 263 else: 264 return op 265 266 267def define_function(func, input_types): 268 """Creates a `FunctionDef` for a python function. 269 270 `func` is a Python function that receives zero or more tensors and returns at 271 least one tensor. It should add ops to the default graph the usual way by 272 calling TensorFlow functions such as `tf.constant()`, `tf.matmul()`, etc. 273 274 `input_types` is a dictionary of strings to `tf.Dtype` objects. Keys are 275 names arguments to `func`. The value indicate the type of tensor expected 276 by the function. 277 278 The returned `FunctionDef` protocol buffer is also added to the 279 default graph library. After it has been added you can add calls to 280 the function by passing it to `tf.call_function()`, together with a 281 list of tensors to use as inputs for the function. 282 283 Notes: 284 285 * `func` is called once, with `placeholder` tensors of the types specified in 286 `input_types` as arguments. 287 * Values returned by `func` must be tensors and they are recorded as being 288 the output of the function def. 289 * While `func` is a called, an empty graph is temporarily pushed as the 290 default graph. 
All ops added by `func` to that graph are part of the body 291 of the returned function def. 292 293 Example, but also see the [How To on functions](link_needed). 294 295 ```python 296 # A function that receives two tensors x, y and returns their 297 # sum and difference. 298 def my_func(x, y): 299 return x + y, x - y 300 301 # Create a FunctionDef for 'my_func'. (This does not change the default 302 graph.) 303 my_func_def = tf.define_function(my_func, {'x': tf.float32, 'y': tf.float32}) 304 305 # Build the graph, calling the function. 306 a = tf.constant([1.0]) 307 b = tf.constant([2.0]) 308 c, d = tf.call_function(my_func_def, a, b, name='mycall') 309 ``` 310 311 Args: 312 func: a Python function. 313 input_types: dict. Keys are the names of the arguments of `func`, values 314 are their expected `tf.DType`. 315 316 Returns: 317 A FunctionDef protocol buffer. 318 319 Raises: 320 ValueError: if the arguments are invalid. 321 322 """ 323 # TODO(touts): Lift the limitation that func can only receive Tensor args. 324 if inspect.isfunction(func): 325 func_name = func.__name__ 326 elif inspect.ismethod(func): 327 func_name = func.__self__.__name__ + "." + func.__name__ 328 else: 329 raise ValueError("Argument must be a function") 330 argspec = inspect.getargspec(func) 331 if argspec.varargs or argspec.keywords or argspec.defaults: 332 raise ValueError("Only functions with plain arglists are supported.") 333 if inspect.isfunction(func): 334 if len(argspec.args) != len(input_types): 335 raise ValueError("The function must have the same number of arguments " 336 "as the number of specified input types.") 337 args = argspec.args 338 elif inspect.ismethod(func): 339 if len(argspec.args) != 1 + len(input_types): 340 raise ValueError( 341 "The class function must have the same number of arguments " 342 "as the number of specified input types.") 343 args = argspec.args[1:] # 1st argument is the "class" type. 344 345 # Create the func_def object. 
346 temp_graph = ops.Graph() 347 with temp_graph.as_default(): 348 # List of placeholders for the function_def. 349 inputs = [] 350 # Arglist to call 'func' 351 kwargs = {} 352 for argname in args: 353 if argname not in input_types: 354 raise ValueError("Missing type for argument: " + argname) 355 argholder = array_ops.placeholder(input_types[argname], name=argname) 356 inputs.append(argholder) 357 kwargs[argname] = argholder 358 # Call func and gather the output tensors. 359 outputs = func(**kwargs) 360 if not outputs: 361 raise ValueError("Function must return at least one tensor") 362 # Convenience: if func only returned one value, make it a tuple. 363 if not isinstance(outputs, (list, tuple)): 364 outputs = (outputs,) 365 # Build the FunctionDef 366 func_def = graph_to_function_def(temp_graph, func_name, inputs, outputs) 367 g = ops.get_default_graph() 368 g._add_function(func_def) # pylint: disable=protected-access 369 return func_def 370 371 372class Defun(object): 373 """Decorator used to define TensorFlow functions. 374 375 Use this decorator to make a Python function usable directly as a TensorFlow 376 function. 377 378 The decorated function must add ops to the default graph and return zero or 379 more `Tensor` objects. Call the decorator with named arguments, one for each 380 argument of the function to decorate, with the expected type of the argument 381 as value. 382 383 For example if the function to decorate accepts to `tf.float32` arguments 384 named `x` and `y`, call the decorator with: 385 386 @Defun(x=tf.float32, y=tf.float32) 387 def foo(x, y): 388 ... 389 390 When you call the decorated function it will add `call` ops to the graph. 391 392 Example, but also see the [How To on functions](link_needed). 393 394 ```python 395 # Defining the function. 396 @tf.Defun(x=tf.float32, y=tf.float32) 397 def MyFunc(x, y): 398 return x + y, x - y 399 400 # Building the graph. 
401 a = tf.Constant([1.0]) 402 b = tf.Constant([2.0]) 403 c, d = MyFunc(a, b, name='mycall') 404 ``` 405 406 @@__init__ 407 """ 408 409 def __init__(self, **input_types): 410 """Create a `Defun` decorator. 411 412 Args: 413 **input_types: Dict mapping string with `tf.DType` 414 One key for each argument of the function to decorate. 415 """ 416 self._input_types = input_types 417 418 def __call__(self, f): 419 func_def = define_function(f, self._input_types) 420 return lambda *args, **kwargs: call_function(func_def, *args, **kwargs) 421