1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Activity analysis."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22
23import gast
24
25from tensorflow.contrib.py2tf.pyct import anno
26from tensorflow.contrib.py2tf.pyct import transformer
27from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
28
29# TODO(mdan): Add support for PY3 (e.g. Param vs arg).
30
31
class Scope(object):
  """Records which symbols a region of code creates, modifies and reads.

  Scopes are chained through `parent`, but they do not necessarily line up
  with Python's own scoping rules; for example, the body of an `if` statement
  may be given a Scope of its own.

  Attributes:
    parent: the enclosing Scope, or None for a root scope.
    isolated: when False, writes and returns recorded here are forwarded to
        `parent`, and `referenced`, `is_modified_since_entry` and `is_param`
        also consult `parent`.
    modified: identifiers modified in this scope
    created: identifiers created in this scope
    used: identifiers referenced in this scope
    params: identifiers registered through `mark_param`
    returned: identifiers registered through `mark_returned`
  """

  # Names of the per-scope symbol sets handled by copy_from and merge_from.
  _SYMBOL_SETS = ('modified', 'created', 'used', 'params', 'returned')

  def __init__(self, parent, isolated=True):
    """Creates an empty scope.

    Args:
      parent: A Scope or None.
      isolated: If True (the default), writes in this scope always create
          the symbol locally and are not forwarded to `parent`. If False,
          writes are also recorded in `parent`, and a symbol already known to
          `parent` is not considered created in this scope.
    """
    self.parent = parent
    self.isolated = isolated
    self.used = set()
    self.modified = set()
    self.created = set()
    self.params = set()
    self.returned = set()

  # TODO(mdan): Rename to `locals`
  @property
  def referenced(self):
    """Symbols read here, plus those of the ancestors when not isolated."""
    if self.isolated or self.parent is None:
      return self.used
    return self.used | self.parent.referenced

  def __repr__(self):
    reads = tuple(self.used)
    creates = tuple(self.created)
    writes = tuple(self.modified)
    return 'Scope{r=%s, c=%s, w=%s}' % (reads, creates, writes)

  def copy_from(self, other):
    """Replaces every symbol set with a shallow copy of `other`'s.

    `parent` and `isolated` are left untouched.
    """
    for attr in self._SYMBOL_SETS:
      setattr(self, attr, copy.copy(getattr(other, attr)))

  def merge_from(self, other):
    """Adds, in place, every symbol recorded in `other` to this scope."""
    for attr in self._SYMBOL_SETS:
      mine = getattr(self, attr)
      mine |= getattr(other, attr)

  def has(self, name):
    """Whether `name` is modified or a param here or in any ancestor.

    Unlike the other lookups, this one ignores `isolated`.
    """
    scope = self
    while scope is not None:
      if name in scope.modified or name in scope.params:
        return True
      scope = scope.parent
    return False

  def _search_unisolated(self, name, attr):
    """Looks for `name` in set `attr`, walking up through non-isolated links."""
    scope = self
    while True:
      if name in getattr(scope, attr):
        return True
      if scope.parent is None or scope.isolated:
        return False
      scope = scope.parent

  def is_modified_since_entry(self, name):
    return self._search_unisolated(name, 'modified')

  def is_param(self, name):
    return self._search_unisolated(name, 'params')

  def mark_read(self, name):
    """Records a read of `name`, forwarding it upward until a scope that
    created `name` is reached (that scope records the read too)."""
    scope = self
    while True:
      scope.used.add(name)
      if scope.parent is None or name in scope.created:
        return
      scope = scope.parent

  def mark_param(self, name):
    self.params.add(name)

  def mark_creation(self, name):
    """Records `name` as created in this scope.

    Raises:
      ValueError: if `name` is composite and its parent symbol is unknown.
    """
    if not name.is_composite():
      self.created.add(name)
      return
    owner = name.parent
    # A composite name whose parent is known is considered a mutation of the
    # parent, not a creation, so nothing is recorded.
    # TODO(mdan): Is that really so?
    if not self.has(owner):
      raise ValueError('Unknown symbol "%s".' % owner)

  def mark_write(self, name):
    """Records a write of `name`, forwarding it to `parent` if not isolated."""
    self.modified.add(name)
    parent = self.parent
    if self.isolated or parent is None:
      self.mark_creation(name)
      return
    if not parent.has(name):
      self.mark_creation(name)
    parent.mark_write(name)

  def mark_returned(self, name):
    """Records `name` as returned, here and in non-isolated ancestors."""
    scope = self
    while True:
      scope.returned.add(name)
      if scope.isolated or scope.parent is None:
        return
      scope = scope.parent
142
143
class ActivityAnalizer(transformer.Base):
  """Annotates nodes with local scope information. See Scope.

  Every visited Name and Attribute node receives the NodeAnno.IS_LOCAL,
  NodeAnno.IS_MODIFIED_SINCE_ENTRY and NodeAnno.IS_PARAM annotations. Call and
  Print nodes receive the Scope built for their arguments (NodeAnno.ARGS_SCOPE),
  and If/For/While nodes receive the Scopes built for their body and orelse
  blocks (NodeAnno.BODY_SCOPE, NodeAnno.ORELSE_SCOPE).
  """

  def __init__(self, context, parent_scope):
    super(ActivityAnalizer, self).__init__(context)
    self.scope = Scope(parent_scope)
    # True while visiting a `return` statement; every symbol tracked during
    # that time is also recorded through Scope.mark_returned.
    self._in_return_statement = False

  def _track_symbol(self, node):
    """Records the access made by a Name or Attribute node and annotates it.

    Args:
      node: a Name or Attribute node carrying an anno.Basic.QN annotation.

    Raises:
      ValueError: if the node's context is not Store, Load or Param.
    """
    qn = anno.getanno(node, anno.Basic.QN)

    if isinstance(node.ctx, gast.Store):
      self.scope.mark_write(qn)
    elif isinstance(node.ctx, gast.Load):
      self.scope.mark_read(qn)
    elif isinstance(node.ctx, gast.Param):
      # Param contexts appear in function defs, so they have the meaning of
      # defining a variable.
      # TODO(mdan): This may be incorrect with nested functions.
      # For nested functions, we'll have to add the notion of hiding args from
      # the parent scope, not writing to them.
      self.scope.mark_creation(qn)
      self.scope.mark_param(qn)
    else:
      raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))

    anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))
    anno.setanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY,
                 self.scope.is_modified_since_entry(qn))
    anno.setanno(node, NodeAnno.IS_PARAM, self.scope.is_param(qn))

    if self._in_return_statement:
      self.scope.mark_returned(qn)

  def visit_Name(self, node):
    self.generic_visit(node)
    self._track_symbol(node)
    return node

  def visit_Attribute(self, node):
    # Children (e.g. the object expression) are tracked before the attribute
    # itself.
    self.generic_visit(node)
    self._track_symbol(node)
    return node

  def visit_Print(self, node):
    # Print arguments are analyzed in an isolated child scope.
    current_scope = self.scope
    args_scope = Scope(current_scope)
    self.scope = args_scope
    for n in node.values:
      self.visit(n)
    anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope)
    self.scope = current_scope
    return node

  def visit_Call(self, node):
    # Arguments are analyzed in a non-isolated child scope (so writes reach
    # the enclosing scope); the callee is visited afterwards in the enclosing
    # scope itself.
    current_scope = self.scope
    args_scope = Scope(current_scope, isolated=False)
    self.scope = args_scope
    for n in node.args:
      self.visit(n)
    # TODO(mdan): Account starargs, kwargs
    for n in node.keywords:
      self.visit(n)
    anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope)
    self.scope = current_scope
    self.visit(node.func)
    return node

  def _process_block_node(self, node, block, scope_name):
    """Visits the statements of `block` in a new non-isolated child scope.

    The child scope is stored on `node` under the annotation key `scope_name`.
    Returns `node`.
    """
    current_scope = self.scope
    block_scope = Scope(current_scope, isolated=False)
    self.scope = block_scope
    for n in block:
      self.visit(n)
    anno.setanno(node, scope_name, block_scope)
    self.scope = current_scope
    return node

  def _process_parallel_blocks(self, parent, children):
    """Visits sibling blocks so that each one starts from the same state.

    Args:
      parent: the node owning the blocks; it receives each block's scope.
      children: iterable of (statement list, annotation key) pairs.

    Returns:
      `parent`, after the current scope has been set to the union of the
      states left by all children.
    """
    # Because the scopes are not isolated, processing any child block
    # modifies the parent state causing the other child blocks to be
    # processed incorrectly. So we need to checkpoint the parent scope so that
    # each child sees the same context.
    # NOTE(review): only self.scope is checkpointed. If it is non-isolated
    # itself, Scope.mark_write also forwards writes to its ancestors, which
    # are not restored between children.
    before_parent = Scope(None)
    before_parent.copy_from(self.scope)
    after_children = []
    for child, scope_name in children:
      self.scope.copy_from(before_parent)
      parent = self._process_block_node(parent, child, scope_name)
      after_child = Scope(None)
      after_child.copy_from(self.scope)
      after_children.append(after_child)
    for after_child in after_children:
      self.scope.merge_from(after_child)
    return parent

  def visit_If(self, node):
    self.visit(node.test)
    node = self._process_parallel_blocks(node,
                                         ((node.body, NodeAnno.BODY_SCOPE),
                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
    return node

  def visit_For(self, node):
    # Python evaluates the iterable before binding the loop target, so visit
    # `iter` first. Otherwise, in e.g. `for x in x:`, the read of `x` would be
    # seen after the target write: it would be annotated as modified/local
    # and, if the write created `x`, the read would not reach the parent.
    self.visit(node.iter)
    self.visit(node.target)
    node = self._process_parallel_blocks(node,
                                         ((node.body, NodeAnno.BODY_SCOPE),
                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
    return node

  def visit_While(self, node):
    self.visit(node.test)
    node = self._process_parallel_blocks(node,
                                         ((node.body, NodeAnno.BODY_SCOPE),
                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
    return node

  def visit_Return(self, node):
    # Symbols tracked inside the return statement are marked as returned.
    # The flag is restored even if visiting raises, so the analyzer is never
    # left in return mode.
    previous = self._in_return_statement
    self._in_return_statement = True
    try:
      node = self.generic_visit(node)
    finally:
      self._in_return_statement = previous
    return node
267
268
def resolve(node, context, parent_scope=None):
  """Runs activity analysis on `node`, annotating it with scope information.

  Args:
    node: the AST node to analyze.
    context: forwarded to the ActivityAnalizer constructor.
    parent_scope: optional Scope enclosing `node`. The analyzer's root scope
        is created as an isolated child of it.

  Returns:
    The result of the analyzer's visit of `node`.
  """
  analyzer = ActivityAnalizer(context, parent_scope)
  return analyzer.visit(node)
271