function.py revision 01925fe23a5d8ebf17a1ddfcfaa1503f67e32575
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Python front-end supports for functions.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import inspect 22import re 23 24from six.moves import xrange # pylint: disable=redefined-builtin 25 26from tensorflow.core.framework import attr_value_pb2 27from tensorflow.core.framework import function_pb2 28from tensorflow.core.framework import op_def_pb2 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import op_def_registry 31from tensorflow.python.framework import ops 32from tensorflow.python.ops import array_ops 33 34 35def _make_argname_from_tensor_name(name): 36 return re.sub(":0$", "", name).replace(":", "_o") 37 38 39def _tensor_to_argdef(t): 40 arg = op_def_pb2.OpDef.ArgDef() 41 arg.name = _make_argname_from_tensor_name(t.name) 42 arg.type = t.dtype.as_datatype_enum 43 return arg 44 45 46def _get_node_def_attr(op): 47 # pylint: disable=protected-access 48 return op._node_def.attr 49 # pylint: enable=protected-access 50 51 52def _add_input_array(op, start, limit, dtype, func): 53 """Adds a _ListToArray node in the func for op.inputs[start:limit].""" 54 node = function_pb2.FunctionDef.Node() 55 node.op = "_ListToArray" 56 ret_name = op.name + "_L2A_" + str(start) 57 node.ret.extend([ret_name]) 58 node.arg.extend([_make_argname_from_tensor_name(x.name) 59 for x in op.inputs[start:limit]]) 60 num = limit - start 61 node.attr["Tin"].CopyFrom(attr_value_pb2.AttrValue( 62 list=attr_value_pb2.AttrValue.ListValue(type=[dtype] * num))) 63 node.attr["T"].CopyFrom(attr_value_pb2.AttrValue(type=dtype)) 64 node.attr["N"].CopyFrom(attr_value_pb2.AttrValue(i=num)) 65 func.node.extend([node]) 66 return ret_name 67 68 69def _add_output_array(op, start, limit, dtype, func): 70 """Adds a _ArrayToList node in the func for op.outputs[start:limit].""" 71 dtype_proto = attr_value_pb2.AttrValue(type=dtype) 72 # A node converting N*T to list(T) 73 node = function_pb2.FunctionDef.Node() 74 node.op = "_ArrayToList" 75 arg_name = op.name + "_A2L_" + str(start) 76 ret_name = arg_name + "_out" 77 node.ret.append(ret_name) 78 node.arg.append(arg_name) 79 node.attr["T"].CopyFrom(dtype_proto) 80 num = limit - start 81 node.attr["N"].CopyFrom(attr_value_pb2.AttrValue(i=num)) 82 node.attr["out_types"].CopyFrom(attr_value_pb2.AttrValue( 83 list=attr_value_pb2.AttrValue.ListValue(type=[dtype] * num))) 84 func.node.extend([node]) 85 num = limit - start 86 # Adds an identity node for each element in the array N*T so that 87 # uses of each element can be added easily later. These Identity 88 # will be eliminated before graph execution. 89 for i in xrange(num): 90 node = function_pb2.FunctionDef.Node() 91 node.op = "Identity" 92 node.arg.append(ret_name + ":" + str(i)) 93 node.ret.append(_make_argname_from_tensor_name(op.outputs[i].name)) 94 node.attr["T"].CopyFrom(dtype_proto) 95 func.node.extend([node]) 96 return arg_name 97 98 99def _add_output_list(op, start, limit, dtype_lst, func): 100 """Adds a _ArrayToList node in the func for op.outputs[start:limit].""" 101 ret_name = op.name + "_Lst_" + str(start) + "_" + str(limit) 102 num = limit - start 103 assert len(dtype_lst) == num 104 # Adds an identity node for each element in the array N*T so that 105 # uses of each element can be added easily later. These Identity 106 # will be eliminated before graph execution. 107 for i in xrange(num): 108 node = function_pb2.FunctionDef.Node() 109 node.op = "Identity" 110 node.arg.append(ret_name + ":" + str(i)) 111 node.ret.append(_make_argname_from_tensor_name(op.outputs[i].name)) 112 node.attr["T"].CopyFrom(attr_value_pb2.AttrValue(type=dtype_lst[i])) 113 func.node.extend([node]) 114 return ret_name 115 116 117def _add_op_node(graph, op, func): 118 """Converts an op to a function def node and add it to `func`.""" 119 node = function_pb2.FunctionDef.Node() 120 node.op = op.type 121 # pylint: disable=protected-access 122 if graph._is_function(op.type): 123 op_def = graph._get_function(op.type).signature 124 else: 125 op_def = op_def_registry.get_registered_ops()[op.type] 126 # pylint: enable=protected-access 127 attrs = _get_node_def_attr(op) 128 out_index = 0 129 for arg_def in op_def.output_arg: 130 if arg_def.number_attr: 131 dtype = arg_def.type or attrs[arg_def.type_attr].type 132 num = attrs[arg_def.number_attr].i 133 node.ret.append(_add_output_array(op, out_index, out_index + num, dtype, 134 func)) 135 out_index += num 136 elif arg_def.type_list_attr: 137 dtype_lst = attrs[arg_def.type_list_attr].list.type 138 num = len(dtype_lst) 139 node.ret.append(_add_output_list(op, out_index, out_index + num, 140 dtype_lst, func)) 141 out_index += num 142 else: 143 node.ret.append(_make_argname_from_tensor_name(op.outputs[ 144 out_index].name)) 145 out_index += 1 146 inp_index = 0 147 for arg_def in op_def.input_arg: 148 if arg_def.number_attr: 149 dtype = arg_def.type or attrs[arg_def.type_attr].type 150 num = attrs[arg_def.number_attr].i 151 node.arg.append(_add_input_array(op, inp_index, inp_index + num, dtype, 152 func)) 153 inp_index += num 154 elif arg_def.type_list_attr: 155 num = len(attrs[arg_def.type_list_attr].list.type) 156 node.arg.extend([_make_argname_from_tensor_name(op.inputs[i].name) 157 for i in range(inp_index, inp_index + num)]) 158 inp_index += num 159 else: 160 node.arg.append(_make_argname_from_tensor_name(op.inputs[inp_index].name)) 161 inp_index += 1 162 node.dep.extend([_make_argname_from_tensor_name(x.name) 163 for x in op.control_inputs]) 164 for k, v in _get_node_def_attr(op).items(): 165 node.attr[k].CopyFrom(v) 166 func.node.extend([node]) 167 168 169# pylint: disable=line-too-long 170def graph_to_function_def(graph, name, inputs, outputs): 171 """Returns `graph` as a `FunctionDef` protocol buffer. 172 173 This method creates a [`FunctionDef`]( 174 https://www.tensorflow.org/code/tensorflow/core/framework/function.proto) 175 protocol buffer that contains all the ops present in the graph. The 176 graph effectively becomes the body of the function. 177 178 The arguments `inputs` and `outputs` will be listed as the inputs 179 and outputs tensors of the function. They must be lists of 180 tensors present in the graph. The lists can optionally be empty. 181 182 The returned protocol buffer can be passed to the 183 [`Graph.add_function()`](#Graph.add_function) method of a 184 different graph to make it available there. 185 186 Args: 187 graph: GraphDef proto. 188 name: string. The name to use for the function. 189 inputs: List of tensors. Inputs to the function. 190 outputs: List of tensors. Outputs of the function. 191 192 Returns: 193 A FunctionDef protocol buffer. 194 """ 195 # pylint: enable=line-too-long 196 func = function_pb2.FunctionDef() 197 func.signature.name = name 198 func.signature.input_arg.extend([_tensor_to_argdef(graph.get_tensor_by_name( 199 i.name)) for i in inputs]) 200 func.signature.output_arg.extend([_tensor_to_argdef(graph.get_tensor_by_name( 201 o.name)) for o in outputs]) 202 func_arg_placeholders = set([i.name for i in inputs]) 203 g = ops.get_default_graph() 204 for op in graph.get_operations(): 205 tensor_name = op.values()[0].name 206 if tensor_name not in func_arg_placeholders: 207 _add_op_node(g, op, func) 208 return func 209 210 211def call_function(func_def, *inputs, **kwargs): 212 """Calls the function described by `func_def`. 213 214 This adds a `call` op to the default graph that calls the function described 215 by `func_def` with the tensors listed in `inputs` as arguments. It returns 216 the outputs of the call, which are one or more tensors. 217 218 `func_def` is a 219 [`FunctionDef`]( 220 https://www.tensorflow.org/code/tensorflow/core/framework/function.proto) 221 protcol buffer describing a 222 TensorFlow function. See [`define_function()`](#define_function) for an 223 easy way to create one from a Python function. 224 225 You can pass an optional keyword parameter `name=string` to name the 226 added operation. 227 228 You can pass an optional keyword parameter `noinline=True|False` to instruct 229 the runtime not to inline the function body into the call site. 230 231 `func_def` is automatically added to the function library of the graph if 232 needed. 233 234 Args: 235 func_def: A `FunctionDef` protocol buffer. 236 *inputs: A list of tensors 237 **kwargs: Optional keyword arguments. Can only contain 'name'. 238 239 Returns: 240 A list of tensors representing the outputs of the call to `func_def`. 241 242 Raises: 243 ValueError: if the arguments are invalid. 244 """ 245 name = kwargs.pop("name", None) 246 noinline = kwargs.pop("noinline", None) 247 if noinline is None: 248 attrs = None 249 else: 250 attrs = {} 251 attrs["noinline"] = attr_value_pb2.AttrValue(b=bool(noinline)) 252 if kwargs: 253 raise ValueError("Unknown keyword arguments: %s" % kwargs.keys()) 254 func_name = func_def.signature.name 255 with ops.op_scope(inputs, name, func_name) as name: 256 if len(inputs) != len(func_def.signature.input_arg): 257 raise ValueError("Expected number of arguments: %d, received: %d" % 258 (len(func_def.signature.input_arg), len(inputs))) 259 output_types = [dtypes.DType(x.type) for x in func_def.signature.output_arg] 260 # TODO(touts): Pass compute_shapes as "try if function exists" 261 g = ops.get_default_graph() 262 op = g.create_op(func_name, 263 list(inputs), 264 output_types, 265 name=name, 266 attrs=attrs, 267 compute_shapes=False) 268 if op.outputs: 269 if len(op.outputs) == 1: 270 return op.outputs[0] 271 else: 272 return tuple(op.outputs) 273 else: 274 return op 275 276 277def _get_func_name(func): 278 if isinstance(func, _DefinedFunction): 279 return func.name 280 elif callable(func): 281 if inspect.isfunction(func): 282 return func.__name__ 283 elif inspect.ismethod(func): 284 return "%s.%s" % (func.__self__.__name__, func.__name__) 285 else: # Probably a class instance with __call__ 286 return type(func) 287 else: 288 raise ValueError("Argument must be callable") 289 290 291def define_function(func, input_types, func_name=None, grad_func=None, 292 python_grad_func=None): 293 """Creates a `FunctionDef` for a python function. 294 295 `func` is a Python function that receives zero or more tensors and returns at 296 least one tensor. It should add ops to the default graph the usual way by 297 calling TensorFlow functions such as `tf.constant()`, `tf.matmul()`, etc. 298 299 `input_types` is a dictionary of strings to `tf.Dtype` objects. Keys are 300 names arguments to `func`. The value indicate the type of tensor expected 301 by the function. 302 303 The returned `FunctionDef` protocol buffer is also added to the 304 default graph library. After it has been added you can add calls to 305 the function by passing it to `tf.call_function()`, together with a 306 list of tensors to use as inputs for the function. 307 308 Notes: 309 310 * `func` is called once, with `placeholder` tensors of the types specified in 311 `input_types` as arguments. 312 * Values returned by `func` must be tensors and they are recorded as being 313 the output of the function def. 314 * While `func` is a called, an empty graph is temporarily pushed as the 315 default graph. All ops added by `func` to that graph are part of the body 316 of the returned function def. 317 318 Example, but also see the [How To on functions](link_needed). 319 320 ```python 321 # A function that receives two tensors x, y and returns their 322 # sum and difference. 323 def my_func(x, y): 324 return x + y, x - y 325 326 # Create a FunctionDef for 'my_func'. (This does not change the default 327 graph.) 328 my_func_def = tf.define_function(my_func, {'x': tf.float32, 'y': tf.float32}) 329 # Alternatively: 330 # my_func_def = tf.define_function(my_func, [tf.float32, tf.float32]) 331 332 # Build the graph, calling the function. 333 a = tf.constant([1.0]) 334 b = tf.constant([2.0]) 335 c, d = tf.call_function(my_func_def, a, b, name='mycall') 336 ``` 337 338 Args: 339 func: a Python function. 340 input_types: if a dict, keys are the names of the arguments of 341 `func`, values are their expected `tf.DType`. Otherwise, 342 a list of `tf.DType`s. 343 func_name: Pyton string. If not None, specifies the name to use when 344 creating the Function. By default, introspection on `func` is used to 345 generate a name. 346 grad_func: If not None, specifies the gradient function. The 347 gradient function must satisify the criterion defined in 348 function.proto:GradientDef. 349 python_grad_func: If not None, specifies the gradient function with the same 350 interface as that expected by `tf.RegisterGradient`. This 351 will be called by tf.gradients to add the gradient ops to the 352 graph. No more than one of {grad_func, python_grad_func} may be 353 specified. 354 355 Returns: 356 A FunctionDef protocol buffer. 357 358 Raises: 359 ValueError: if the arguments are invalid. 360 361 """ 362 # TODO(touts): Lift the limitation that func can only receive Tensor args. 363 func_name = func_name or _get_func_name(func) 364 grad_func_name = _get_func_name(grad_func) if grad_func is not None else None 365 366 argspec = inspect.getargspec(func) 367 if argspec.keywords or argspec.defaults: 368 raise ValueError("Functions with argument defaults or keyword " 369 "arguments are not supported.") 370 if inspect.isfunction(func): 371 if argspec.varargs and ( 372 len(argspec.args) > len(input_types)) or not argspec.varargs and ( 373 len(argspec.args) != len(input_types)): 374 raise ValueError("The function has fewer arguments " 375 "than the number of specified input types.") 376 argnames = argspec.args 377 elif inspect.ismethod(func): 378 if argspec.varargs and ( 379 len(argspec.args) > 1 + len(input_types)) or not argspec.varargs and ( 380 len(argspec.args) != 1 + len(input_types)): 381 raise ValueError("The class function has fewer arguments " 382 "than the number of specified input types.") 383 # 1st argument is the "class" type. 384 argnames = argspec.args[1:] 385 386 args = [] 387 if isinstance(input_types, (list, tuple)): 388 for i in range(len(input_types)): 389 argname = argnames[i] if i < len(argnames) else ("arg%d" % i) 390 argtype = input_types[i] 391 args.append((argname, argtype)) 392 else: 393 for name in argnames: 394 if name not in input_types: 395 raise ValueError("Missing type for argument: " + name) 396 args.append((name, input_types[name])) 397 398 # Create the func_def object. 399 temp_graph = ops.Graph() 400 with temp_graph.as_default(): 401 # List of placeholders for the function_def. 402 inputs = [] 403 # Arglist to call 'func' 404 kwargs = {} 405 for (argname, argtype) in args: 406 argholder = array_ops.placeholder(argtype, name=argname) 407 inputs.append(argholder) 408 kwargs[argname] = argholder 409 # Call func and gather the output tensors. 410 if isinstance(input_types, (list, tuple)): 411 outputs = func(*inputs) 412 else: 413 outputs = func(**kwargs) 414 if not isinstance(outputs, ops.Tensor) and not outputs: 415 raise ValueError("Function must return at least one tensor") 416 # Convenience: if func only returned one value, make it a tuple. 417 if not isinstance(outputs, (list, tuple)): 418 outputs = (outputs,) 419 # Build the FunctionDef 420 func_def = graph_to_function_def(temp_graph, func_name, inputs, outputs) 421 g = ops.get_default_graph() 422 # pylint: disable=protected-access 423 g._add_function(func_def, grad_func_name, python_grad_func=python_grad_func) 424 # pylint: enable=protected-access 425 return func_def 426 427 428class Defun(object): 429 """Decorator used to define TensorFlow functions. 430 431 Use this decorator to make a Python function usable directly as a TensorFlow 432 function. 433 434 The decorated function must add ops to the default graph and return zero or 435 more `Tensor` objects. Call the decorator with named arguments, one for each 436 argument of the function to decorate, with the expected type of the argument 437 as value. 438 439 For example if the function to decorate accepts two `tf.float32` arguments 440 named `x` and `y`, call the decorator with: 441 442 @Defun(tf.float32, tf.float32) 443 def foo(x, y): 444 ... 445 446 When you call the decorated function it will add `call` ops to the graph. 447 448 Example, but also see the [How To on functions](link_needed). 449 450 ```python 451 # Defining the function. 452 @tf.Defun(tf.float32, tf.float32) 453 def MyFunc(x, y): 454 return x + y, x - y 455 456 # Building the graph. 457 a = tf.Constant([1.0]) 458 b = tf.Constant([2.0]) 459 c, d = MyFunc(a, b, name='mycall') 460 ``` 461 462 @@__init__ 463 """ 464 465 def __init__(self, *input_type_list, **input_types): 466 """Create a `Defun` decorator. 467 468 Args: 469 *input_type_list: A list of `tf.DType` 470 **input_types: Dict mapping string with `tf.DType` 471 One key for each argument of the function to decorate. 472 473 Note that these optional keyword arguments are also accepted: 474 func_name - (optional). A python string, the name to use to declare 475 this `Function` in the graph. 476 477 grad_func - (optional). A function implementing the gradient of the 478 function-to-register. This is usually a previously 479 `Defun`-registered Python callable. 480 481 python_grad_func - (optional). A function implementing the gradient of 482 the function python-side. This function must take the current op and 483 the gradients w.r.t. its outputs, and return the gradients w.r.t. the 484 inputs (identical to the interface expected by 485 `tf.RegisterGradient`). 486 """ 487 self._func_name = input_types.pop("func_name", None) 488 self._grad_func = input_types.pop("grad_func", None) 489 self._python_grad_func = input_types.pop("python_grad_func", None) 490 assert not input_type_list or not input_types, ( 491 "Can't specify both *input_type_list and **input_types") 492 self._input_types = input_types 493 self._input_type_list = input_type_list 494 495 def __call__(self, f): 496 if self._input_types: 497 func_def = define_function( 498 f, self._input_types, 499 func_name=self._func_name, grad_func=self._grad_func, 500 python_grad_func=self._python_grad_func) 501 else: 502 func_def = define_function( 503 f, self._input_type_list, 504 func_name=self._func_name, grad_func=self._grad_func, 505 python_grad_func=self._python_grad_func) 506 507 return _DefinedFunction(definition=func_def) 508 509 510class _DefinedFunction(object): 511 """Class to store the name and definition of the function defined by Defun. 512 513 This object implements a callable interface that runs `call_function`, and 514 provides a `name` property to look up the name of the `Function`. 515 516 An instance of `_DefinedFunction` may be passed to the `grad_func` parameter 517 of `define_function` and `Defun`. 518 """ 519 520 def __init__(self, definition): 521 self._definition = definition 522 523 @property 524 def name(self): 525 return self._definition.signature.name 526 527 def __call__(self, *args, **kwargs): 528 return call_function(self._definition, *args, **kwargs) 529