type_info.py revision 55cd506ab8220c6a1075965eb7839cac4af1db3e
13f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
23f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower#
33f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License");
43f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# you may not use this file except in compliance with the License.
53f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# You may obtain a copy of the License at
63f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower#
73f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower#     http://www.apache.org/licenses/LICENSE-2.0
83f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower#
93f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software
103f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS,
113f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# See the License for the specific language governing permissions and
133f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# limitations under the License.
143f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# ==============================================================================
153f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower"""Type resolution.
163f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
173f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlowerRequires annotations generated by LiveValuesResolver.
183f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower"""
193f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
203f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlowerfrom __future__ import absolute_import
213f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlowerfrom __future__ import division
223f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlowerfrom __future__ import print_function
233f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
243f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlowerimport gast
253f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
263f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlowerfrom tensorflow.contrib.py2tf.pyct import anno
273f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlowerfrom tensorflow.python.util import tf_inspect
283f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
293f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
class Scope(object):
  """Encloses symbol value references.

  Scopes may be chained through `parent`. Both `hasval` and `getval` look in
  this scope first and then fall back to the parent chain; `setval` always
  writes to this scope only.

  Attributes:
    parent: The enclosing Scope, or None for the outermost scope.
    values: A dict mapping string to gast.Node, containing the value that was
        most recently assigned to the symbol.
  """

  # TODO(mdan): Should rather use a CFG here?

  def __init__(self, parent):
    """Create a new scope.

    Args:
      parent: A Scope or None.
    """
    self.parent = parent
    self.values = {}

  def __repr__(self):
    # Materialize the keys so that the output is a plain list under both
    # Python 2 and Python 3 (rather than `dict_keys([...])` on Python 3).
    return 'Scope[%s]' % list(self.values.keys())

  def copy(self):
    """Returns a new Scope with the same parent and a shallow copy of values.

    The parent is shared, not copied; only the symbol table of this scope is
    duplicated, so later `setval` calls on either scope do not affect the other.
    """
    s = Scope(self.parent)
    s.values = self.values.copy()
    return s

  def setval(self, name, value):
    """Records `value` as the most recent value of `name` in this scope."""
    self.values[name] = value

  def hasval(self, name):
    """Returns True if `name` is defined in this scope or in any ancestor."""
    return (name in self.values or
            (self.parent is not None and self.parent.hasval(name)))

  def getval(self, name):
    """Returns the most recent value of `name`.

    The lookup mirrors `hasval`: this scope is consulted first, then the parent
    chain, so `getval` succeeds whenever `hasval` returns True.

    Args:
      name: The symbol to look up.

    Returns:
      The value stored for `name` in the nearest scope that defines it.

    Raises:
      KeyError: If neither this scope nor any ancestor defines `name`.
    """
    if name in self.values:
      return self.values[name]
    if self.parent is not None:
      return self.parent.getval(name)
    raise KeyError(name)
663f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
673f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
class TypeInfoResolver(gast.NodeTransformer):
  """Annotates symbols with type information where possible.

  Nodes currently annotated:
    * Call (helps detect class constructors)
    * Attribute (helps resolve object methods)

  The resolver reads the 'live_val' and 'fqn' annotations on call targets (per
  the module docstring, these are expected to come from LiveValuesResolver) and
  writes 'type' / 'type_fqn' annotations.

  Attributes:
    scope: A Scope holding, per symbol name, the node most recently bound to it.
    value_hints: None, or a mapping from parameter name to a
        (type_string, type_obj) pair, where type_string is a dot-separated
        name.
    function_level: Nesting depth of the FunctionDef currently being visited
        (0 outside any function).
  """

  def __init__(self, value_hints):
    self.scope = Scope(None)
    self.value_hints = value_hints
    self.function_level = 0

  def visit_FunctionDef(self, node):
    """Visits a function, tracking how deeply function definitions nest."""
    self.function_level += 1
    self.generic_visit(node)
    self.function_level -= 1
    return node

  def visit_Name(self, node):
    """Registers function parameters in the scope, typed via hints if any."""
    self.generic_visit(node)
    if isinstance(node.ctx, gast.Param):
      self.scope.setval(node.id, gast.Name(node.id, gast.Load(), None))
      # Hints are only applied to parameters of the outermost function.
      if (self.function_level == 1 and self.value_hints is not None and
          node.id in self.value_hints):
        # Forge a node to hold the type information, so that method calls on
        # it can resolve the type.
        type_holder = gast.Name(node.id, gast.Load(), None)
        type_string, type_obj = self.value_hints[node.id]
        anno.setanno(type_holder, 'type', type_obj)
        anno.setanno(type_holder, 'type_fqn', type_string.split('.'))
        self.scope.setval(node.id, type_holder)
    return node

  def _bind_target(self, target, source):
    """Records `source` as the latest value of the symbol(s) in `target`.

    Args:
      target: An assignment target node. Name targets are bound directly;
          Tuple and List targets are unpacked element by element (recursively).
          Any other kind of target (e.g. Attribute, Subscript, Starred) does
          not name a plain symbol and is left untracked.
      source: The node representing the assigned value.
    """
    if isinstance(target, gast.Name):
      self.scope.setval(target.id, source)
    elif isinstance(target, (gast.Tuple, gast.List)):
      elts = target.elts
      after_star = False
      for i, elt in enumerate(elts):
        if isinstance(elt, gast.Starred):
          # The starred element absorbs a variable number of items; it is not
          # tracked, and elements after it are addressed from the end.
          after_star = True
          continue
        index = i - len(elts) if after_star else i
        # NOTE(review): the forged Subscript keeps the Store context and raw
        # integer index used historically; confirm what consumers expect.
        self._bind_target(
            elt, gast.Subscript(source, gast.Index(index), ctx=gast.Store()))
    # Other targets are intentionally ignored: they are not simple symbols, so
    # there is no name under which to record the value.

  def _process_variable_assignment(self, source, targets):
    """Annotates constructor calls and records the values bound to targets.

    Args:
      source: The node of the value being assigned.
      targets: An iterable of target nodes that all receive `source`.
    """
    if isinstance(source, gast.Call):
      func = source.func
      if anno.hasanno(func, 'live_val'):
        func_obj = anno.getanno(func, 'live_val')
        if tf_inspect.isclass(func_obj):
          # This is then a constructor.
          anno.setanno(source, 'type', func_obj)
          anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
          # TODO(mdan): Raise an error if constructor has side effects.
          # We can have a whitelist of no-side-effects constructors.
          # We can also step inside the constructor and further analyze.

    for t in targets:
      self._bind_target(t, source)

  def visit_With(self, node):
    """Binds `with ... as name` variables before visiting the statement."""
    for wi in node.items:
      if wi.optional_vars is not None:
        self._process_variable_assignment(wi.context_expr, (wi.optional_vars,))
    self.generic_visit(node)
    return node

  def visit_Assign(self, node):
    """Visits the assignment, then records the value bound to its targets."""
    self.generic_visit(node)
    self._process_variable_assignment(node.value, node.targets)
    return node

  def visit_Call(self, node):
    """Annotates the target of a method call with its object's type FQN.

    Call targets that already carry a 'live_val' annotation are left untouched.

    Args:
      node: The Call node being visited.

    Returns:
      The same node.

    Raises:
      ValueError: If the call target has no 'live_val' annotation and is not
          an attribute of a plain name, or if no type is known for that name.
      KeyError: If the name the method is called on is not in the scope.
    """
    target = node.func
    if not anno.hasanno(target, 'live_val'):
      if not isinstance(target, gast.Attribute):
        # Suspecting this pattern would reach here:
        #   foo = bar
        #   foo()
        raise ValueError('Dont know how to handle dynamic functions.')
      if not isinstance(target.value, gast.Name):
        # Possible example of this kind:
        #   foo = module.Foo()
        #   foo.bar.baz()
        # TODO(mdan): This should be doable by using the FQN.
        raise ValueError('Dont know how to handle object properties yet.')
      # In the example below, object_source is 'tf.train.Optimizer()':
      #   opt = tf.train.Optimizer()
      #   opt.foo()
      object_source = self.scope.getval(target.value.id)
      if not anno.hasanno(object_source, 'type'):
        raise ValueError('Could not determine type of "%s". Is it dynamic?' %
                         (target.value.id))
      anno.setanno(target, 'type_fqn', anno.getanno(object_source, 'type_fqn'))
    self.generic_visit(node)
    return node

  def visit_While(self, node):
    """Snapshots the scope as it is on loop entry onto the While node."""
    anno.setanno(node, 'parent_scope_values', self.scope.copy())
    self.generic_visit(node)
    return node
1653f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
1663f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
def resolve(node, value_hints):
  """Runs a fresh TypeInfoResolver over `node`.

  Args:
    node: The AST node to visit.
    value_hints: Passed unchanged to the TypeInfoResolver constructor; either
        None or a mapping from parameter name to a (type_string, type_obj)
        pair.

  Returns:
    The result of the resolver's `visit` call on `node`.
  """
  resolver = TypeInfoResolver(value_hints)
  return resolver.visit(node)
169