1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Class to hold a library of OpDefs and use it to create Brain operations."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import six
23
24from tensorflow.core.framework import attr_value_pb2
25from tensorflow.core.framework import op_def_pb2
26from tensorflow.core.framework import tensor_pb2
27from tensorflow.core.framework import tensor_shape_pb2
28from tensorflow.core.framework import types_pb2
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.platform import tf_logging as logging
33from tensorflow.python.util import compat
34from tensorflow.python.util import tf_contextlib
35
36
37def _Attr(op_def, name):
38  for attr in op_def.attr:
39    if attr.name == name:
40      return attr
41  raise TypeError("Inconsistent OpDef for '%s', missing attr '%s'" %
42                  (op_def.name, name))
43
44
45def _AttrValue(attr_protos, name):
46  if name in attr_protos:
47    return attr_protos[name]
48  raise TypeError("Inconsistent OpDef, missing attr '%s' from '%s'." %
49                  (name, attr_protos))
50
51
52def _SatisfiesTypeConstraint(dtype, attr_def, param_name):
53  if attr_def.HasField("allowed_values"):
54    allowed_list = attr_def.allowed_values.list.type
55    if dtype not in allowed_list:
56      raise TypeError(
57          "Value passed to parameter '%s' has DataType %s not in list of "
58          "allowed values: %s" %
59          (param_name, dtypes.as_dtype(dtype).name,
60           ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
61
62
63def _IsListParameter(arg):
64  if arg.number_attr:
65    return True
66  elif arg.type_list_attr:
67    return True
68  return False
69
70
def _NumTypeFields(arg):
  """Counts how many of the three mutually-exclusive type fields `arg` sets."""
  flags = (arg.type != types_pb2.DT_INVALID,
           bool(arg.type_attr),
           bool(arg.type_list_attr))
  return sum(flags)
77
78
79def _IsListValue(v):
80  return isinstance(v, (list, tuple))
81
82
83def _Flatten(l):
84  """Converts [1, 2, [3, 4], [5]] to [1, 2, 3, 4, 5]."""
85  # [1, 2, [3, 4], [5]] -> [[1], [2], [3, 4], [5]]
86  l_of_l = [x if _IsListValue(x) else [x] for x in l]
87  # [[1], [2], [3, 4], [5]] -> [1, 2, 3, 4, 5]
88  return [item for sublist in l_of_l for item in sublist]
89
90
91def _Restructure(l, structure):
92  """Returns the elements of list l structured according to the given structure.
93
94  A structure is represented by a list whose elements are either
95  `None` or a non-negative integer. `None` corresponds to a single
96  element in the output list, and an integer N corresponds to a nested
97  list of length N.
98
99  The function returns a data structure whose shape is given by
100  `structure`, and whose elements are taken from `l`. If `structure`
101  is a singleton, the function returns the single data structure
102  implied by the 0th element of `structure`. For example:
103
104      _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
105        -> ["foo", ["bar", "baz"], "qux"]
106
107      _Restructure(["foo"], [None]) -> "foo"
108
109      _Restructure(["foo"], [1]) -> ["foo"]
110
111      _Restructure([], [0]) -> []
112
113  Args:
114    l: A list.
115    structure: A list whose elements are either `None` or a non-negative
116      integer.
117
118  Returns:
119    The elements of `l`, restructured according to `structure`. If
120    `structure` is a list of length 1, this function returns the
121    single data structure implied by `structure[0]`.
122
123  """
124  result = []
125  current_index = 0
126  for element in structure:
127    if element is None:
128      result.append(l[current_index])
129      current_index += 1
130    else:
131      result.append(l[current_index:current_index+element])
132      current_index += element
133
134  if len(result) == 1:
135    return result[0]
136  else:
137    return tuple(result)
138
139
def _MakeFloat(v, arg_name):
  """Coerces `v` to a Python float; raises TypeError for non-real values."""
  if isinstance(v, compat.real_types):
    return float(v)
  raise TypeError("Expected float for argument '%s' not %s." %
                  (arg_name, repr(v)))
145
146
def _MakeInt(v, arg_name):
  """Coerces `v` to int, rejecting strings (which int() would parse)."""
  # Build the error once: both the string check and a failed int() conversion
  # report the identical message.
  error = TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
  if isinstance(v, six.string_types):
    raise error
  try:
    return int(v)
  except (ValueError, TypeError):
    raise error
156
157
def _MakeStr(v, arg_name):
  """Coerces a text or bytes value to bytes; raises TypeError otherwise."""
  if isinstance(v, compat.bytes_or_text_types):
    # Unicode strings are encoded; bytes pass through unchanged.
    return compat.as_bytes(v)
  raise TypeError("Expected string for argument '%s' not %s." %
                  (arg_name, repr(v)))
163
164
165def _MakeBool(v, arg_name):
166  if not isinstance(v, bool):
167    raise TypeError("Expected bool for argument '%s' not %s." %
168                    (arg_name, repr(v)))
169  return v
170
171
def _MakeType(v, attr_def):
  """Converts `v` to a DataType enum value satisfying `attr_def`'s constraints."""
  try:
    base = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (attr_def.name, repr(v)))
  enum_value = base.as_datatype_enum
  # Reject dtypes outside the attr's allowed_values list, if one is declared.
  _SatisfiesTypeConstraint(enum_value, attr_def, param_name=attr_def.name)
  return enum_value
181
182
def _MakeShape(v, arg_name):
  """Convert v into a TensorShapeProto.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    A TensorShapeProto.

  Raises:
    TypeError: If `v` cannot be interpreted as a shape.
    ValueError: If `v` is an invalid shape value.
  """
  if isinstance(v, tensor_shape_pb2.TensorShapeProto):
    # Named dimensions are legal in the proto but unexpected here; warn once.
    for d in v.dim:
      if d.name:
        logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                        str(v))
        break
    return v
  try:
    return tensor_shape.as_shape(v).as_proto()
  except TypeError as e:
    raise TypeError("Error converting %s to a TensorShape: %s" % (arg_name, e))
  except ValueError as e:
    raise ValueError("Error converting %s to a TensorShape: %s" % (arg_name, e))
204
205
def _MakeTensor(v, arg_name):
  """Ensure v is a TensorProto."""
  if not isinstance(v, tensor_pb2.TensorProto):
    raise TypeError(
        "Don't know how to convert %s to a TensorProto for argument '%s'" %
        (repr(v), arg_name))
  return v
213
214
class _OpInfo(object):
  """All per-Op state we would like to precompute/validate."""

  def __init__(self, op_def):
    """Validates `op_def` (an op_def_pb2.OpDef) and stores it.

    Raises:
      TypeError: If any input/output arg references attrs inconsistently.
    """
    self.op_def = op_def
    # TODO(josh11b): SWIG the ValidateOpDef() function from C++ and call it
    # here, instead of these checks.
    for arg in list(op_def.input_arg) + list(op_def.output_arg):
      num_type_fields = _NumTypeFields(arg)
      if num_type_fields != 1:
        raise TypeError("Arg '%s' of '%s' must have one type field not %d" %
                        (arg.name, op_def.name, num_type_fields))
      if arg.type_attr:
        attr_type = _Attr(op_def, arg.type_attr).type
        if attr_type != "type":
          raise TypeError("Attr '%s' of '%s' used as a type_attr "
                          "but has type %s" %
                          (arg.type_attr, op_def.name, attr_type))
      if arg.type_list_attr:
        attr_type = _Attr(op_def, arg.type_list_attr).type
        if attr_type != "list(type)":
          # Bug fix: report the offending type_list_attr (previously this
          # message interpolated arg.type_attr, which is empty here).
          raise TypeError(
              "Attr '%s' of '%s' used as a type_list_attr but has type %s" %
              (arg.type_list_attr, op_def.name, attr_type))
      if arg.number_attr:
        attr_type = _Attr(op_def, arg.number_attr).type
        if attr_type != "int":
          raise TypeError(
              "Attr '%s' of '%s' used as a number_attr but has type %s" %
              (arg.number_attr, op_def.name, attr_type))
245
246
# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _MaybeColocateWith(inputs):
  """A context manager for (maybe) colocating with a list of input tensors.

  If `inputs` is empty this is a no-op context; otherwise the managed
  region is colocated with every element of `inputs`.

  Args:
    inputs: A list of `Tensor` or `Operation` objects.

  Returns:
    A context manager.
  """
  if not inputs:
    yield
  else:
    # NOTE(mrry): The `ops.colocate_with()` function accepts only a single
    # op or tensor, so we create one context manager per element in the list.
    # The recursion nests one colocate_with scope per input, innermost last.
    with ops.colocate_with(inputs[0]), _MaybeColocateWith(inputs[1:]):
      yield
# pylint: enable=g-doc-return-or-yield
266
267
class OpDefLibrary(object):
  """Holds a collection of OpDefs, can add the corresponding Ops to a graph."""

  def __init__(self):
    # Maps op name -> _OpInfo (validated OpDef plus precomputed state).
    self._ops = {}

  # pylint: disable=invalid-name
  def add_op(self, op_def):
    """Register an OpDef. May call apply_op with the name afterwards."""
    if not isinstance(op_def, op_def_pb2.OpDef):
      raise TypeError("%s is %s, not an op_def_pb2.OpDef" %
                      (op_def, type(op_def)))
    if not op_def.name:
      raise ValueError("%s missing name." % op_def)
    if op_def.name in self._ops:
      raise RuntimeError("Op name %s registered twice." % op_def.name)
    self._ops[op_def.name] = _OpInfo(op_def)

  def add_op_list(self, op_list):
    """Register the OpDefs from an OpList."""
    if not isinstance(op_list, op_def_pb2.OpList):
      raise TypeError("%s is %s, not an op_def_pb2.OpList" %
                      (op_list, type(op_list)))
    for op_def in op_list.op:
      self.add_op(op_def)

  def apply_op(self, op_type_name, name=None, **keywords):
    # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs.

    Raises:
      RuntimeError: On some errors.
      TypeError: On some errors.
      ValueError: On some errors.
    """
    output_structure, is_stateful, op = self._apply_op_helper(
        op_type_name, name, **keywords)
    if output_structure:
      outputs = op.outputs
      res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
      # A stateful op whose (restructured) output is an empty list is more
      # useful as the Operation itself (e.g. for control dependencies).
      if isinstance(res, list) and not res and is_stateful:
        return op
      else:
        return res
    else:
      return op

  def _apply_op_helper(self, op_type_name, name=None, **keywords):
    """Implementation of apply_op that returns output_structure, op.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name.

    Returns:
      A triple `(output_structure, is_stateful, op)` where
      `output_structure` is the argument for `_Restructure`, `is_stateful`
      comes from the OpDef, and `op` is the created Operation.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
      raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
      # Need to flatten all the arguments into a list.
      # pylint: disable=protected-access
      g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
      # pylint: enable=protected-access
    except AssertionError as e:
      # Bug fix: exceptions have no `.message` attribute on Python 3
      # (PEP 352); format the exception object itself instead.
      raise RuntimeError(
          "Cannot determine graph for Op '%s' due to: %s"
          % (op_type_name, e))

    # Default name if not specified.
    if name is None:
      name = op_type_name

    # Check for deprecation
    deprecation_version = op_def.deprecation.version
    if deprecation_version:
      producer = g.graph_def_versions.producer
      if producer >= deprecation_version:
        raise NotImplementedError(
            ("Op %s is not available in GraphDef version %d. "
             "It has been removed in version %d. %s.") %
            (op_type_name, producer, deprecation_version,
             op_def.deprecation.explanation))

    # Fill in the list of default types for all "type" attrs.  This
    # will be used to choose a preferred dtype to convert to in the
    # absence of input type information.
    #
    # TODO(b/31302892): Currently the defaults don't work in the right
    # way if you have two inputs, one of whose type resolution depends
    # on the other.  Handling this will require restructuring this code
    # significantly.
    default_type_attr_map = {}
    for attr_def in op_def.attr:
      if attr_def.type != "type":
        continue
      key = attr_def.name
      if attr_def.HasField("default_value"):
        default_type_attr_map[key] = dtypes.as_dtype(
            attr_def.default_value.type)

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

      # Perform input type inference
      inferred_from = {}
      for input_arg in op_def.input_arg:
        input_name = input_arg.name
        if input_name in keywords:
          values = keywords.pop(input_name)
        elif input_name + "_" in keywords:
          # Handle the case where the name is a keyword or built-in
          # for Python so we use the name + _ instead.
          input_name += "_"
          values = keywords.pop(input_name)
        else:
          raise TypeError("No argument for input " + input_name)

        # Goals:
        # * Convert values to Tensors if it contains constants.
        # * Verify that values is a list if that matches the input_arg's
        #   type.
        # * If the input_arg's type is determined by attrs, either set
        #   those attrs and validate those attr values are legal (if
        #   they have not yet been set) or validate the input matches
        #   the type indicated by the attrs (if they have already been
        #   inferred via an earlier input).
        # * If the input_arg has an explicit type, make sure the input
        #   conforms.

        if _IsListParameter(input_arg):
          if not _IsListValue(values):
            raise TypeError(
                "Expected list for '%s' argument to '%s' Op, not %s." %
                (input_name, op_type_name, values))
          # In cases where we expect all elements of the list to have the
          # same dtype, try to cast non-Tensor elements to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.number_attr:
            if input_arg.type_attr in attrs:
              dtype = attrs[input_arg.type_attr]
            else:
              # First Tensor in the list determines the dtype to cast
              # the non-Tensor elements to.
              for t in values:
                if isinstance(t, ops.Tensor):
                  dtype = t.dtype
                  break

            # dtype still not found, prefer using the default dtype
            # from the attr.
            if dtype is None and input_arg.type_attr in default_type_attr_map:
              default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            if not input_arg.is_ref and dtype:
              dtype = dtypes.as_dtype(dtype).base_dtype
            values = ops.internal_convert_n_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype if dtype else None,
                preferred_dtype=default_dtype,
                as_ref=input_arg.is_ref)
            if input_arg.number_attr and len(
                set(v.dtype.base_dtype for v in values)) > 1:
              raise TypeError()  # All types should match.
          except (TypeError, ValueError):
            # What types does the conversion function think values have?
            observed_types = []
            for value in values:
              try:
                converted_value = ops.internal_convert_to_tensor(
                    value, as_ref=input_arg.is_ref)
                observed_types.append(converted_value.dtype.base_dtype.name)
              except (TypeError, ValueError):
                observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
            observed = ", ".join(observed_types)

            prefix = (
                "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                (input_name, op_type_name, observed))
            if input_arg.number_attr:
              if input_arg.type != types_pb2.DT_INVALID:
                raise TypeError("%s that do not match expected type %s." %
                                (prefix, dtype.name))
              elif input_arg.type_attr in attrs:
                raise TypeError("%s that do not match type %s inferred from "
                                "earlier arguments." %
                                (prefix, dtype.name))
              else:
                raise TypeError("%s that don't all match." % prefix)
            else:
              raise TypeError("%s that are invalid." % prefix)

          types = [x.dtype for x in values]
          inputs.extend(values)
        else:
          # In cases where we have an expected type, try to convert non-Tensor
          # arguments to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          elif input_arg.type_attr in default_type_attr_map:
            # The dtype could not be inferred solely from the inputs,
            # so we prefer the attr's default, so code that adds a new attr
            # with a default is backwards compatible.
            default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            values = ops.internal_convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
          except TypeError as err:
            if dtype is None:
              raise err
            else:
              raise TypeError(
                  "Expected %s passed to parameter '%s' of op '%s', got %s of "
                  "type '%s' instead." %
                  (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name,
                   repr(values), type(values).__name__))
          except ValueError:
            # What type does convert_to_tensor think it has?
            try:
              observed = ops.internal_convert_to_tensor(
                  values, as_ref=input_arg.is_ref).dtype.name
            except ValueError as err:
              raise ValueError(
                  "Tried to convert '%s' to a tensor and failed. Error: %s" %
                  (input_name, err))
            prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                      (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s expected type of %s." %
                              (prefix, dtypes.as_dtype(input_arg.type).name))
            else:
              # Update the maps with the default, if needed.
              k = input_arg.type_attr
              if k in default_type_attr_map:
                if k not in attrs:
                  attrs[k] = default_type_attr_map[k]
                  if k not in inferred_from:
                    inferred_from[k] = "Default in OpDef"

              raise TypeError(
                  "%s type %s of argument '%s'." %
                  (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                   inferred_from[input_arg.type_attr]))

          types = [values.dtype]
          inputs.append(values)
        base_types = [x.base_dtype for x in types]

        if input_arg.number_attr:
          # <number-attr> * <type> or <number-attr> * <type-attr>
          if input_arg.number_attr in attrs:
            if len(values) != attrs[input_arg.number_attr]:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d must match "
                  "length %d of argument '%s'." %
                  (input_name, op_type_name, len(values),
                   attrs[input_arg.number_attr],
                   inferred_from[input_arg.number_attr]))
          else:
            attrs[input_arg.number_attr] = len(values)
            inferred_from[input_arg.number_attr] = input_name
            num_attr = _Attr(op_def, input_arg.number_attr)
            if num_attr.has_minimum and len(values) < num_attr.minimum:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d shorter "
                  "than minimum length %d." %
                  (input_name, op_type_name, len(values), num_attr.minimum))
          # All tensors must have the same base type.
          if any([bt != base_types[0] for bt in base_types]):
            raise TypeError(
                "All tensors passed to '%s' of '%s' Op "
                "must have the same type." %
                (input_name, op_type_name))
          if input_arg.type != types_pb2.DT_INVALID:
            # <number-attr> * <type> case
            if base_types and base_types[0] != input_arg.type:
              assert False, "Unreachable"
          elif input_arg.type_attr in attrs:
            # <number-attr> * <type-attr> case, where <type-attr> already
            # has an inferred value.
            if base_types and base_types[0] != attrs[input_arg.type_attr]:
              assert False, "Unreachable"
          else:
            # <number-attr> * <type-attr> case, where we are now setting
            # the <type-attr> based on this input
            if not base_types:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr,
                                     param_name=input_name)
        elif input_arg.type_attr:
          # <type-attr>
          attr_value = base_types[0]
          if input_arg.type_attr in attrs:
            if attrs[input_arg.type_attr] != attr_value:
              assert False, "Unreachable"
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_attr),
                                       param_name=input_name)
            attrs[input_arg.type_attr] = attr_value
            inferred_from[input_arg.type_attr] = input_name
        elif input_arg.type_list_attr:
          # <type-list-attr>
          attr_value = base_types
          if input_arg.type_list_attr in attrs:
            if attrs[input_arg.type_list_attr] != attr_value:
              raise TypeError(
                  "Input '%s' of '%s' Op has type list of %s that does not "
                  "match type list %s of argument '%s'." %
                  (input_name, op_type_name,
                   ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                   ", ".join(dtypes.as_dtype(x).name
                             for x in attrs[input_arg.type_list_attr]),
                   inferred_from[input_arg.type_list_attr]))
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_list_attr),
                                       param_name=input_name)
            attrs[input_arg.type_list_attr] = attr_value
            inferred_from[input_arg.type_list_attr] = input_name
        else:
          # single Tensor with specified type
          if base_types[0] != input_arg.type:
            assert False, "Unreachable"

        if input_arg.is_ref:
          if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
            raise TypeError(
                ("'%s' Op requires that input '%s' be a mutable tensor "
                 "(e.g.: a tf.Variable)") % (op_type_name, input_name))
          input_types.extend(types)
        else:
          input_types.extend(base_types)

      # Process remaining attrs
      for attr in op_def.attr:
        # Skip attrs that have already had their values inferred
        if attr.name in attrs:
          if attr.name in keywords:
            raise TypeError(
                "Should not specify value for inferred attr '%s'." % attr.name)
          continue
        if attr.name in keywords:
          attrs[attr.name] = keywords.pop(attr.name)
        elif attr.name + "_" in keywords:
          # Attrs whose names match Python keywords have an extra '_'
          # appended, so we must check for that as well.
          attrs[attr.name] = keywords.pop(attr.name + "_")
        else:
          raise TypeError("No argument for attr " + attr.name)

      # Convert attr values to AttrValue protos.
      attr_protos = {}
      for attr_def in op_def.attr:
        key = attr_def.name
        value = attrs[key]
        attr_value = attr_value_pb2.AttrValue()
        if attr_def.HasField("default_value") and value is None:
          attr_value.CopyFrom(attr_def.default_value)
          attr_protos[key] = attr_value
          continue
        if attr_def.type.startswith("list("):
          if not _IsListValue(value):
            raise TypeError("Expected list for attr " + key)
          if attr_def.has_minimum:
            if len(value) < attr_def.minimum:
              raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                               "less than minimum %d." %
                               (key, op_type_name, len(value),
                                attr_def.minimum))
          attr_value.list.SetInParent()
        if attr_def.type == "string":
          attr_value.s = _MakeStr(value, key)
          if attr_def.HasField("allowed_values"):
            if attr_value.s not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(attr_value.s),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
        elif attr_def.type == "list(string)":
          attr_value.list.s.extend([_MakeStr(x, key) for x in value])
          if attr_def.HasField("allowed_values"):
            for x in attr_value.list.s:
              if x not in attr_def.allowed_values.list.s:
                raise ValueError(
                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                    (key, op_type_name, compat.as_text(x),
                     '", "'.join(map(compat.as_text,
                                     attr_def.allowed_values.list.s))))
        elif attr_def.type == "int":
          attr_value.i = _MakeInt(value, key)
          if attr_def.has_minimum:
            if attr_value.i < attr_def.minimum:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                  (key, op_type_name, attr_value.i, attr_def.minimum))
        elif attr_def.type == "list(int)":
          attr_value.list.i.extend([_MakeInt(x, key) for x in value])
        elif attr_def.type == "float":
          attr_value.f = _MakeFloat(value, key)
        elif attr_def.type == "list(float)":
          attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
        elif attr_def.type == "bool":
          attr_value.b = _MakeBool(value, key)
        elif attr_def.type == "list(bool)":
          attr_value.list.b.extend([_MakeBool(x, key) for x in value])
        elif attr_def.type == "type":
          attr_value.type = _MakeType(value, attr_def)
        elif attr_def.type == "list(type)":
          attr_value.list.type.extend(
              [_MakeType(x, attr_def) for x in value])
        elif attr_def.type == "shape":
          attr_value.shape.CopyFrom(_MakeShape(value, key))
        elif attr_def.type == "list(shape)":
          attr_value.list.shape.extend(
              [_MakeShape(x, key) for x in value])
        elif attr_def.type == "tensor":
          attr_value.tensor.CopyFrom(_MakeTensor(value, key))
        elif attr_def.type == "list(tensor)":
          attr_value.list.tensor.extend(
              [_MakeTensor(x, key) for x in value])
        elif attr_def.type == "func":
          if isinstance(value, attr_value_pb2.NameAttrList):
            attr_value.func.CopyFrom(value)
          elif isinstance(value, compat.bytes_or_text_types):
            attr_value.func.name = value
          else:
            value.add_to_graph(ops.get_default_graph())
            attr_value.func.name = value.name
        else:
          raise TypeError("Unrecognized Attr type " + attr_def.type)

        attr_protos[key] = attr_value
      del attrs  # attrs is no longer authoritative, use attr_protos instead

      # Determine output types (possibly using attrs)
      output_types = []
      output_structure = []
      for arg in op_def.output_arg:
        types = []
        if arg.number_attr:
          n = _AttrValue(attr_protos, arg.number_attr).i
          if arg.type_attr:
            types = [_AttrValue(attr_protos, arg.type_attr).type] * n
          else:
            types = [arg.type] * n
          output_structure.append(n)
        elif arg.type_attr:
          t = _AttrValue(attr_protos, arg.type_attr)
          types = [t.type]
          output_structure.append(None)
        elif arg.type_list_attr:
          t = _AttrValue(attr_protos, arg.type_list_attr)
          types = t.list.type
          output_structure.append(len(types))
        else:
          types = [arg.type]
          output_structure.append(None)
        if arg.is_ref:
          types = [dtypes.as_dtype(x)._as_ref for x in types]  # pylint: disable=protected-access
        output_types.extend(types)

      if keywords:
        raise TypeError("apply_op() got unexpected keyword arguments: " +
                        ", ".join(sorted(keywords.keys())))

      # NOTE(mrry): We add an explicit colocation constraint between
      # the newly created op and any of its reference-typed inputs.
      must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                              if arg.is_ref]
      with _MaybeColocateWith(must_colocate_inputs):
        # Add Op to graph
        op = g.create_op(op_type_name, inputs, output_types, name=scope,
                         input_types=input_types, attrs=attr_protos,
                         op_def=op_def)
      return output_structure, op_def.is_stateful, op
789
790# pylint: enable=invalid-name
791