# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Activity analysis."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

import gast

from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import transformer
from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno

# TODO(mdan): Add support for PY3 (e.g. Param vs arg).


class Scope(object):
  """Encloses local symbol definition and usage information.

  This can track for instance whether a symbol is modified in the current scope.
  Note that scopes do not necessarily align with Python's scopes. For example,
  the body of an if statement may be considered a separate scope.

  Attributes:
    isolated: whether this scope is isolated from its parent (see __init__)
    parent: the enclosing Scope, or None
    modified: identifiers modified in this scope
    created: identifiers created in this scope
    used: identifiers referenced in this scope
    params: identifiers marked as parameters in this scope
    returned: identifiers marked as returned in this scope
  """

  def __init__(self, parent, isolated=True):
    """Create a new scope.

    Args:
      parent: A Scope or None.
      isolated: Whether the scope is isolated, that is, whether variables
          created in this scope should be visible to the parent scope.
    """
    self.isolated = isolated
    self.parent = parent
    self.modified = set()
    self.created = set()
    self.used = set()
    self.params = set()
    self.returned = set()

  # TODO(mdan): Rename to `locals`
  @property
  def referenced(self):
    """The symbols used here, plus the parent's if this scope is not isolated."""
    if not self.isolated and self.parent is not None:
      return self.used | self.parent.referenced
    return self.used

  def __repr__(self):
    return 'Scope{r=%s, c=%s, w=%s}' % (tuple(self.used), tuple(self.created),
                                        tuple(self.modified))

  def copy_from(self, other):
    """Overwrites this scope's symbol sets with shallow copies of other's.

    Only the symbol sets are copied; `parent` and `isolated` are left untouched.

    Args:
      other: The Scope to copy the symbol sets from.
    """
    self.modified = copy.copy(other.modified)
    self.created = copy.copy(other.created)
    self.used = copy.copy(other.used)
    self.params = copy.copy(other.params)
    self.returned = copy.copy(other.returned)

  def merge_from(self, other):
    """Adds (set union) all of other's symbol sets into this scope's sets.

    Args:
      other: The Scope whose symbol sets are merged in.
    """
    self.modified |= other.modified
    self.created |= other.created
    self.used |= other.used
    self.params |= other.params
    self.returned |= other.returned

  def has(self, name):
    """Returns whether name is modified or a param here or in any ancestor.

    Unlike `is_modified_since_entry` and `is_param`, the lookup continues into
    the parent even when this scope is isolated.

    Args:
      name: The symbol to look up.
    """
    if name in self.modified or name in self.params:
      return True
    elif self.parent is not None:
      return self.parent.has(name)
    return False

  def is_modified_since_entry(self, name):
    """Returns whether name is modified here or in a non-isolated ancestor chain.

    The search stops at the first isolated scope (which is itself included).

    Args:
      name: The symbol to look up.
    """
    if name in self.modified:
      return True
    elif self.parent is not None and not self.isolated:
      return self.parent.is_modified_since_entry(name)
    return False

  def is_param(self, name):
    """Returns whether name is a param here or in a non-isolated ancestor chain.

    The search stops at the first isolated scope (which is itself included).

    Args:
      name: The symbol to look up.
    """
    if name in self.params:
      return True
    elif self.parent is not None and not self.isolated:
      return self.parent.is_param(name)
    return False

  def mark_read(self, name):
    """Records a read of name.

    The read is also recorded on the parent, unless the symbol was created in
    this scope.

    Args:
      name: The symbol that was read.
    """
    self.used.add(name)
    if self.parent is not None and name not in self.created:
      self.parent.mark_read(name)

  def mark_param(self, name):
    """Records name as a parameter of this scope."""
    self.params.add(name)

  def mark_creation(self, name):
    """Records name as created in this scope.

    Composite names (as reported by `name.is_composite()`) are never recorded
    as created: if their parent symbol is known, the operation is treated as a
    mutation of the parent and nothing is recorded.

    Args:
      name: The symbol being created; must provide `is_composite()` and, when
          composite, a `parent` attribute.

    Raises:
      ValueError: if name is composite and its parent is unknown to this scope
          and all of its ancestors (see `has`).
    """
    if name.is_composite():
      parent = name.parent
      if self.has(parent):
        # This is considered mutation of the parent, not creation.
        # TODO(mdan): Is that really so?
        return
      else:
        raise ValueError('Unknown symbol "%s".' % parent)
    self.created.add(name)

  def mark_write(self, name):
    """Records a write to name, marking it created where appropriate.

    In an isolated or parentless scope the symbol is always passed to
    `mark_creation`. In a non-isolated scope it is only passed to
    `mark_creation` if the parent does not already know it, and the write is
    then propagated to the parent.

    Args:
      name: The symbol that was written.

    Raises:
      ValueError: propagated from `mark_creation`.
    """
    self.modified.add(name)
    if self.isolated:
      self.mark_creation(name)
    else:
      if self.parent is None:
        self.mark_creation(name)
      else:
        if not self.parent.has(name):
          self.mark_creation(name)
        self.parent.mark_write(name)

  def mark_returned(self, name):
    """Records name as returned, propagating through non-isolated parents."""
    self.returned.add(name)
    if not self.isolated and self.parent is not None:
      self.parent.mark_returned(name)


class ActivityAnalizer(transformer.Base):
  """Annotates nodes with local scope information. See Scope."""

  def __init__(self, context, parent_scope):
    """Creates the analyzer with a fresh isolated Scope under parent_scope.

    Args:
      context: Passed through to the transformer.Base constructor.
      parent_scope: A Scope or None; the parent of the analyzer's root scope.
    """
    super(ActivityAnalizer, self).__init__(context)
    self.scope = Scope(parent_scope)
    self._in_return_statement = False

  def _visit_in_scope(self, nodes, scope):
    """Visits nodes with self.scope temporarily replaced by scope.

    The previous scope is restored even if visiting raises, so that the
    analyzer is never left pointing at an inner scope.

    Args:
      nodes: An iterable of nodes to visit, in order.
      scope: The Scope to install while visiting.
    """
    enclosing_scope = self.scope
    self.scope = scope
    try:
      for n in nodes:
        self.visit(n)
    finally:
      self.scope = enclosing_scope

  def _track_symbol(self, node):
    """Records the symbol activity of node and annotates node accordingly.

    Args:
      node: A node carrying an anno.Basic.QN annotation and a `ctx` attribute.

    Raises:
      ValueError: if the node's context is not Store, Load or Param, or
          propagated from the Scope bookkeeping.
    """
    qn = anno.getanno(node, anno.Basic.QN)

    if isinstance(node.ctx, gast.Store):
      self.scope.mark_write(qn)
    elif isinstance(node.ctx, gast.Load):
      self.scope.mark_read(qn)
    elif isinstance(node.ctx, gast.Param):
      # Param contexts appear in function defs, so they have the meaning of
      # defining a variable.
      # TODO(mdan): This may be incorrect with nested functions.
      # For nested functions, we'll have to add the notion of hiding args from
      # the parent scope, not writing to them.
      self.scope.mark_creation(qn)
      self.scope.mark_param(qn)
    else:
      raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))

    # The annotations reflect the scope state *after* this node's own activity
    # has been recorded above.
    anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))
    anno.setanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY,
                 self.scope.is_modified_since_entry(qn))
    anno.setanno(node, NodeAnno.IS_PARAM, self.scope.is_param(qn))

    if self._in_return_statement:
      self.scope.mark_returned(qn)

  def visit_Name(self, node):
    self.generic_visit(node)
    self._track_symbol(node)
    return node

  def visit_Attribute(self, node):
    self.generic_visit(node)
    self._track_symbol(node)
    return node

  def visit_Print(self, node):
    """Analyzes the printed values in their own isolated ARGS_SCOPE."""
    args_scope = Scope(self.scope)
    self._visit_in_scope(node.values, args_scope)
    anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope)
    return node

  def visit_Call(self, node):
    """Analyzes call arguments in a non-isolated ARGS_SCOPE, then the callee."""
    args_scope = Scope(self.scope, isolated=False)
    # Positional arguments are visited first, then keywords, all within the
    # arguments scope.
    # TODO(mdan): Account starargs, kwargs
    self._visit_in_scope(list(node.args) + list(node.keywords), args_scope)
    anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope)
    # The callee itself is visited in the enclosing scope, after the arguments.
    self.visit(node.func)
    return node

  def _process_block_node(self, node, block, scope_name):
    """Analyzes a statement block in a new non-isolated child scope.

    Args:
      node: The node owning the block; receives the scope annotation.
      block: An iterable of statement nodes.
      scope_name: The annotation key under which the block's Scope is stored.

    Returns:
      node, annotated with the block's scope.
    """
    block_scope = Scope(self.scope, isolated=False)
    self._visit_in_scope(block, block_scope)
    anno.setanno(node, scope_name, block_scope)
    return node

  def _process_parallel_blocks(self, parent, children):
    """Analyzes sibling blocks so that each starts from the same scope state.

    Args:
      parent: The node owning the blocks.
      children: An iterable of (block, scope_name) pairs.

    Returns:
      parent, annotated with one scope per child block.
    """
    # Because the scopes are not isolated, processing any child block
    # modifies the parent state causing the other child blocks to be
    # processed incorrectly. So we need to checkpoint the parent scope so that
    # each child sees the same context.
    # NOTE: only self.scope is checkpointed. Activity that the child blocks
    # propagate further up, to the ancestors of self.scope, is not rolled back
    # between children.
    before_parent = Scope(None)
    before_parent.copy_from(self.scope)
    after_children = []
    for child, scope_name in children:
      self.scope.copy_from(before_parent)
      parent = self._process_block_node(parent, child, scope_name)
      after_child = Scope(None)
      after_child.copy_from(self.scope)
      after_children.append(after_child)
    # The final parent state is the union of the states left by every child.
    for after_child in after_children:
      self.scope.merge_from(after_child)
    return parent

  def visit_If(self, node):
    self.visit(node.test)
    node = self._process_parallel_blocks(node,
                                         ((node.body, NodeAnno.BODY_SCOPE),
                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
    return node

  def visit_For(self, node):
    self.visit(node.target)
    self.visit(node.iter)
    node = self._process_parallel_blocks(node,
                                         ((node.body, NodeAnno.BODY_SCOPE),
                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
    return node

  def visit_While(self, node):
    self.visit(node.test)
    node = self._process_parallel_blocks(node,
                                         ((node.body, NodeAnno.BODY_SCOPE),
                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
    return node

  def visit_Return(self, node):
    """Visits the return statement, marking symbols tracked in it as returned."""
    was_in_return_statement = self._in_return_statement
    self._in_return_statement = True
    try:
      node = self.generic_visit(node)
    finally:
      # Restore rather than reset, and do so even if visiting raises.
      self._in_return_statement = was_in_return_statement
    return node


def resolve(node, context, parent_scope=None):
  """Runs activity analysis over node.

  Args:
    node: The root node to analyze.
    context: Passed through to the ActivityAnalizer constructor.
    parent_scope: A Scope or None; the parent of the analysis' root scope.

  Returns:
    The result of visiting node with a new ActivityAnalizer.
  """
  return ActivityAnalizer(context, parent_scope).visit(node)