13f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
23f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower#
33f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License");
43f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# you may not use this file except in compliance with the License.
53f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# You may obtain a copy of the License at
63f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower#
73f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower#     http://www.apache.org/licenses/LICENSE-2.0
83f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower#
93f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software
103f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS,
113f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
123f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# See the License for the specific language governing permissions and
133f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# limitations under the License.
143f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower# ==============================================================================
153f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower"""Type resolution.
163f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
173f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlowerRequires annotations generated by LiveValuesResolver.
183f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower"""
193f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
203f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlowerfrom __future__ import absolute_import
213f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlowerfrom __future__ import division
223f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlowerfrom __future__ import print_function
233f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
243f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlowerimport gast
253f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
263f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlowerfrom tensorflow.contrib.py2tf.pyct import anno
27f16cf555905d711c1877039e3b37240e9026c1f2A. Unique TensorFlowerfrom tensorflow.contrib.py2tf.pyct import transformer
283f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlowerfrom tensorflow.python.util import tf_inspect
293f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
303f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
class Scope(object):
  """Encloses symbol value references.

  Scopes are chained through `parent`: lookups that miss in this scope are
  delegated to the parent, while `setval` always binds in this scope only.

  Attributes:
    parent: The enclosing Scope, or None if this is the outermost scope.
    values: A dict mapping a symbol key to the value that was most recently
        assigned to that symbol in this scope. Any hashable key works; the
        resolver in this file uses qualified-name annotations as keys.
  """

  def __init__(self, parent):
    """Create a new scope.

    Args:
      parent: A Scope or None.
    """
    self.parent = parent
    self.values = {}

  def __repr__(self):
    # Materialize the keys so the output reads `Scope[[...]]` on Python 3 as
    # well; formatting the keys() view directly yields `dict_keys([...])`.
    return 'Scope[%s]' % (list(self.values),)

  def copy(self):
    """Returns a Scope with the same parent and a shallow copy of `values`."""
    s = Scope(self.parent)
    s.values = self.values.copy()
    return s

  def setval(self, name, value):
    """Binds `name` to `value` in this scope, shadowing any parent binding."""
    self.values[name] = value

  def hasval(self, name):
    """Returns True if `name` is bound in this scope or in any ancestor."""
    if name in self.values:
      return True
    return self.parent is not None and self.parent.hasval(name)

  def getval(self, name):
    """Returns the innermost binding of `name`.

    Args:
      name: The symbol key to look up.

    Returns:
      The value bound to `name` in the closest scope that defines it.

    Raises:
      KeyError: If neither this scope nor any ancestor binds `name`.
    """
    if name in self.values:
      return self.values[name]
    if self.parent is not None:
      return self.parent.getval(name)
    raise KeyError(name)
693f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
703f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
class TypeInfoResolver(transformer.Base):
  """Annotates symbols with type information where possible.

  Nodes currently annotated:
    * Call, when it is the value of an assignment (or a `with` item) and the
      callee's 'live_val' annotation is a class: 'is_constructor', 'type' and
      'type_fqn' are set (helps detect class constructors).
    * Name, when loaded and the value traced for it in the current scope
      chain carries a 'type' annotation: 'type' and 'type_fqn' are copied
      over (helps resolve object methods).

  Attributes:
    scope: The innermost Scope; tracks the value most recently assigned to
        each symbol.
    function_level: Nesting depth of the FunctionDef currently being visited
        (0 outside of any function).
  """

  def __init__(self, context):
    super(TypeInfoResolver, self).__init__(context)
    self.scope = Scope(None)
    self.function_level = 0

  def visit_FunctionDef(self, node):
    """Visits a function definition inside a fresh child scope."""
    enclosing = self.scope
    self.scope = Scope(enclosing)
    self.function_level += 1
    try:
      self.generic_visit(node)
    finally:
      # Restore the bookkeeping even if a nested visitor raises.
      self.function_level -= 1
      self.scope = enclosing
    return node

  def _visit_block(self, block):
    """Visits a list of statements inside a fresh child scope.

    Args:
      block: A list of statement nodes; updated in place.

    Returns:
      The same list, holding the visited statements.
    """
    enclosing = self.scope
    self.scope = Scope(enclosing)
    try:
      for i, stmt in enumerate(block):
        # Dispatch through visit() rather than generic_visit(): the latter
        # only dispatches the statement's children, so the statement's own
        # visitor (e.g. visit_Assign, or visit_If for a nested conditional)
        # would never run for statements directly inside the block.
        block[i] = self.visit(stmt)
    finally:
      self.scope = enclosing
    return block

  def visit_For(self, node):
    """Visits the loop header, then body and orelse in separate scopes."""
    # visit(), not generic_visit(), so that a header expression that is
    # itself a handled node (e.g. a bare Name) reaches its visitor.
    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def visit_While(self, node):
    """Visits the loop test, then body and orelse in separate scopes."""
    node.test = self.visit(node.test)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def visit_If(self, node):
    """Visits the condition, then body and orelse in separate scopes."""
    node.test = self.visit(node.test)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def _process_function_arg(self, arg_name):
    """Records the declared type of a function argument, if one is known.

    Only arguments seen at function nesting level 1 whose string name has an
    entry in `self.context.arg_types` are recorded.

    Args:
      arg_name: The qualified-name annotation of the argument; it is
          converted with str() for the lookup and must provide `ast()`.
    """
    str_name = str(arg_name)
    if self.function_level != 1 or str_name not in self.context.arg_types:
      return
    # Forge a node to hold the type information, so that method calls on
    # it can resolve the type.
    type_holder = arg_name.ast()
    # Each arg_types entry unpacks into (dotted type name, type object).
    type_string, type_obj = self.context.arg_types[str_name]
    anno.setanno(type_holder, 'type', type_obj)
    anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.')))
    self.scope.setval(arg_name, type_holder)

  def visit_arg(self, node):
    """Records type info for an `arg` node from its QN annotation."""
    self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN))
    return node

  def visit_Name(self, node):
    """Propagates traced type annotations onto loaded names.

    Param-context names are treated as function arguments. Load-context
    names that have a value in the scope chain inherit that value's 'type'
    and 'type_fqn' annotations, when it has a 'type' annotation.

    Args:
      node: A gast.Name carrying an anno.Basic.QN annotation.

    Returns:
      The same node.
    """
    self.generic_visit(node)
    qn = anno.getanno(node, anno.Basic.QN)
    if isinstance(node.ctx, gast.Param):
      self._process_function_arg(qn)
    elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn):
      # E.g. if we had
      # a = b
      # then for future references to `a` we should have traced_source = `b`
      traced_source = self.scope.getval(qn)
      if anno.hasanno(traced_source, 'type'):
        anno.setanno(node, 'type', anno.getanno(traced_source, 'type'))
        anno.setanno(node, 'type_fqn', anno.getanno(traced_source, 'type_fqn'))
    return node

  def _process_variable_assignment(self, source, targets):
    """Tags constructor calls and records `source` as the targets' value.

    Args:
      source: The node being assigned. If it is a Call whose func has a
          'live_val' annotation that is a class, it is annotated as a
          constructor call.
      targets: An iterable of assignment target nodes.

    Raises:
      ValueError: If a target is not a Tuple, Name or Attribute.
    """
    if isinstance(source, gast.Call):
      func = source.func
      if anno.hasanno(func, 'live_val'):
        func_obj = anno.getanno(func, 'live_val')
        if tf_inspect.isclass(func_obj):
          anno.setanno(source, 'is_constructor', True)
          anno.setanno(source, 'type', func_obj)
          anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
          # TODO(mdan): Raise an error if constructor has side effects.
          # We can have a whitelist of no-side-effects constructors.
          # We can also step inside the constructor and further analyze.

    for t in targets:
      if isinstance(t, gast.Tuple):
        # Each unpacked element is traced to a forged `source[i]` node.
        # NOTE(review): every element needs a QN annotation, so nested
        # tuple targets look unsupported here -- confirm.
        for i, e in enumerate(t.elts):
          self.scope.setval(
              anno.getanno(e, anno.Basic.QN),
              gast.Subscript(source, gast.Index(i), ctx=gast.Store()))
      elif isinstance(t, (gast.Name, gast.Attribute)):
        self.scope.setval(anno.getanno(t, anno.Basic.QN), source)
      else:
        raise ValueError('Dont know how to handle assignment to %s' % t)

  def visit_With(self, node):
    """Treats each `with ... as target` item as an assignment, then recurses."""
    for wi in node.items:
      if wi.optional_vars is not None:
        self._process_variable_assignment(wi.context_expr, (wi.optional_vars,))
    self.generic_visit(node)
    return node

  def visit_Assign(self, node):
    """Visits the assignment's children, then records the assigned value."""
    self.generic_visit(node)
    self._process_variable_assignment(node.value, node.targets)
    return node
1833f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
1843f0506007a39b72dc7b06e2fc9df1dca75146f9cA. Unique TensorFlower
def resolve(node, context):
  """Runs a new TypeInfoResolver over `node`.

  Args:
    node: The root node to visit.
    context: Passed through to the TypeInfoResolver constructor.

  Returns:
    Whatever the resolver's visit() returns for `node`.
  """
  resolver = TypeInfoResolver(context)
  return resolver.visit(node)
187