1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Type resolution.

Requires annotations generated by LiveValuesResolver, as well as qualified-name
(anno.Basic.QN) annotations on the symbols being resolved.
"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import gast
25
26from tensorflow.contrib.py2tf.pyct import anno
27from tensorflow.contrib.py2tf.pyct import transformer
28from tensorflow.python.util import tf_inspect
29
30
class Scope(object):
  """A lexical scope recording the value most recently bound to each symbol.

  Scopes are chained through `parent`: a lookup that misses in this scope
  falls back to the enclosing one, and so on up to the outermost scope.

  Attributes:
    parent: The enclosing Scope, or None for the outermost scope.
    values: A dict mapping a symbol key (whatever the caller passes to
        `setval`, e.g. a qualified name) to the gast.Node most recently
        assigned to it in this scope.
  """

  def __init__(self, parent):
    """Create a new, empty scope.

    Args:
      parent: A Scope or None.
    """
    self.parent = parent
    self.values = {}

  def __repr__(self):
    return 'Scope[%s]' % (self.values.keys(),)

  def copy(self):
    """Returns a scope with the same parent and a shallow copy of `values`."""
    duplicate = Scope(self.parent)
    duplicate.values = dict(self.values)
    return duplicate

  def setval(self, name, value):
    """Binds `name` to `value` in this scope (never in a parent)."""
    self.values[name] = value

  def hasval(self, name):
    """Returns True if `name` is bound in this scope or any ancestor."""
    scope = self
    while scope is not None:
      if name in scope.values:
        return True
      scope = scope.parent
    return False

  def getval(self, name):
    """Returns the innermost binding of `name`.

    Raises:
      KeyError: if no scope in the chain binds `name`.
    """
    scope = self
    while scope is not None:
      if name in scope.values:
        return scope.values[name]
      scope = scope.parent
    raise KeyError(name)
69
70
class TypeInfoResolver(transformer.Base):
  """Annotates symbols with type information where possible.

  Nodes currently annotated:
    * Call: calls whose function carries a class as its 'live_val' annotation
      are marked with 'is_constructor', 'type' and 'type_fqn' (helps detect
      class constructors).
    * Name: loads of a symbol whose most recently assigned value carries a
      'type' annotation receive the same 'type' and 'type_fqn' (helps resolve
      object methods, e.g. the receiver `obj` in `obj.method()`).

  Symbol values are tracked in a chain of Scope objects. A new scope is pushed
  for each function body and for each statement block of For, While and If.
  """

  def __init__(self, context):
    super(TypeInfoResolver, self).__init__(context)
    # Innermost active scope; child scopes are pushed and popped as function
    # bodies and statement blocks are visited.
    self.scope = Scope(None)
    # Nesting depth of FunctionDef nodes. Only arguments of the outermost
    # function (depth 1) are matched against context.arg_types.
    self.function_level = 0

  def visit_FunctionDef(self, node):
    self.scope = Scope(self.scope)
    self.function_level += 1
    self.generic_visit(node)
    self.function_level -= 1
    self.scope = self.scope.parent
    return node

  def _visit_block(self, block):
    """Visits a list of statements inside a fresh child scope.

    Each statement is dispatched through `self.visit` (not `generic_visit`),
    so that statement handlers such as visit_Assign, visit_With, visit_If or
    visit_FunctionDef also run for the statements directly in the block.

    Args:
      block: list of gast statement nodes; updated in place.

    Returns:
      The same list, with each element replaced by its visited result.
    """
    self.scope = Scope(self.scope)
    for i, n in enumerate(block):
      block[i] = self.visit(n)
    self.scope = self.scope.parent
    return block

  def visit_For(self, node):
    # Use `visit` rather than `generic_visit` so that a bare Name used as the
    # target or the iterable is itself dispatched to visit_Name.
    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def visit_While(self, node):
    # `visit` ensures a bare Name test is dispatched to visit_Name.
    node.test = self.visit(node.test)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def visit_If(self, node):
    # `visit` ensures a bare Name test is dispatched to visit_Name.
    node.test = self.visit(node.test)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def _process_function_arg(self, arg_name):
    """Seeds the scope with the declared type of a top-level function argument.

    Only applies to arguments of the outermost function whose name appears in
    `self.context.arg_types`, which maps a name to (type_string, type_obj).

    Args:
      arg_name: the qualified name (anno.Basic.QN) of the argument.
    """
    str_name = str(arg_name)
    if self.function_level == 1 and str_name in self.context.arg_types:
      # Forge a node to hold the type information, so that method calls on
      # it can resolve the type.
      type_holder = arg_name.ast()
      type_string, type_obj = self.context.arg_types[str_name]
      anno.setanno(type_holder, 'type', type_obj)
      anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.')))
      self.scope.setval(arg_name, type_holder)

  def visit_arg(self, node):
    # NOTE(review): with gast, function arguments appear as Name nodes with a
    # Param context (handled in visit_Name). This handler reads the QN from
    # `node.arg` rather than `node` -- confirm this path is reachable.
    self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN))
    return node

  def visit_Name(self, node):
    """Propagates type annotations to loads of previously assigned symbols."""
    self.generic_visit(node)
    qn = anno.getanno(node, anno.Basic.QN)
    if isinstance(node.ctx, gast.Param):
      self._process_function_arg(qn)
    elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn):
      # E.g. if we had
      # a = b
      # then for future references to `a` we should have traced_source = `b`
      traced_source = self.scope.getval(qn)
      if anno.hasanno(traced_source, 'type'):
        anno.setanno(node, 'type', anno.getanno(traced_source, 'type'))
        anno.setanno(node, 'type_fqn', anno.getanno(traced_source, 'type_fqn'))
    return node

  def _process_variable_assignment(self, source, targets):
    """Records `source` as the value bound to each of `targets`.

    If `source` is a call whose function's 'live_val' annotation is a class,
    the call is also annotated as a constructor of that type.

    Args:
      source: gast.Node, the value being assigned.
      targets: iterable of gast.Node assignment targets.

    Raises:
      ValueError: if a (possibly nested) target is not a Name, Attribute,
          Tuple or List.
    """
    if isinstance(source, gast.Call):
      func = source.func
      if anno.hasanno(func, 'live_val'):
        func_obj = anno.getanno(func, 'live_val')
        if tf_inspect.isclass(func_obj):
          anno.setanno(source, 'is_constructor', True)
          anno.setanno(source, 'type', func_obj)
          anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
          # TODO(mdan): Raise an error if constructor has side effects.
          # We can have a whitelist of no-side-effects constructors.
          # We can also step inside the constructor and further analyze.

    for t in targets:
      self._bind_target(t, source)

  def _bind_target(self, target, value):
    """Binds a single, possibly nested, assignment target to `value`."""
    if isinstance(target, (gast.Tuple, gast.List)):
      for i, e in enumerate(target.elts):
        # Element i of the unpacked value is modeled by a synthetic subscript
        # node. It carries no 'type' annotation, so unpacked names get no type.
        element = gast.Subscript(value, gast.Index(i), ctx=gast.Store())
        self._bind_target(e, element)
    elif isinstance(target, (gast.Name, gast.Attribute)):
      self.scope.setval(anno.getanno(target, anno.Basic.QN), value)
    else:
      raise ValueError('Dont know how to handle assignment to %s' % target)

  def visit_With(self, node):
    # NOTE: `with expr as var` really binds var to expr.__enter__(); treating
    # it as bound to `expr` itself is an approximation.
    for wi in node.items:
      if wi.optional_vars is not None:
        self._process_variable_assignment(wi.context_expr, (wi.optional_vars,))
    self.generic_visit(node)
    return node

  def visit_Assign(self, node):
    # Visit the children first so that names in the value are annotated
    # before the targets are (re)bound.
    self.generic_visit(node)
    self._process_variable_assignment(node.value, node.targets)
    return node
183
184
def resolve(node, context):
  """Runs TypeInfoResolver over `node` and returns the visited result.

  Args:
    node: gast.Node, the tree to annotate.
    context: the conversion context handed to TypeInfoResolver.

  Returns:
    Whatever TypeInfoResolver.visit returns for `node`.
  """
  resolver = TypeInfoResolver(context)
  annotated = resolver.visit(node)
  return annotated
187