control_flow.py revision d9df4313a98fdc62187a94c5ab6d8955b699e9f2
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles control flow statements: while, if."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import ast_util
from tensorflow.contrib.py2tf.pyct import templates
from tensorflow.contrib.py2tf.pyct import transformer
from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno


class SymbolNamer(object):
  """Describes the interface for ControlFlowTransformer's namer."""

  def new_symbol(self, name_root, reserved_locals):
    """Generate a new unique symbol.

    Args:
      name_root: String, used as stem in the new name.
      reserved_locals: Set(string), additional local symbols that are reserved
          and which should not be used.
    Returns:
      String.
    """
    raise NotImplementedError()


def _stable_order(symbols):
  """Returns `symbols` as a tuple sorted by their string form.

  The symbol collections coming from the scope annotations are sets, whose
  iteration order depends on hashing and may differ between processes.
  Sorting keeps the generated code (and the order in which fresh names are
  requested from the namer) reproducible.
  """
  return tuple(sorted(symbols, key=str))


class ControlFlowTransformer(transformer.Base):
  """Transforms control flow structures like loops and conditionals."""

  def __init__(self, context):
    super(ControlFlowTransformer, self).__init__(context)

  # pylint:disable=invalid-name

  def visit_For(self, node):
    # An explicit raise still fires under `python -O`. A bare `assert` would be
    # stripped there, and this visitor would fall through returning None, which
    # for an ast.NodeTransformer-style visitor removes the node silently.
    raise AssertionError(
        'for statement should have been canonicalized at this point')

  def visit_If(self, node):
    """Converts an `if` statement into a `tf.cond` call.

    Each branch becomes a local function returning every symbol modified in
    either branch. The value of `tf.cond` is assigned back to those symbols.

    Args:
      node: the `if` statement node, annotated with NodeAnno.BODY_SCOPE and
          NodeAnno.ORELSE_SCOPE.
    Returns:
      The replacement produced by templates.replace.
    Raises:
      ValueError: if one branch creates symbols that the other does not.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

    if body_scope.created - orelse_scope.created:
      raise ValueError(
          'The if branch creates new symbols that the else branch does not.')
    if orelse_scope.created - body_scope.created:
      raise ValueError(
          'The else branch creates new symbols that the if branch does not.')

    modified = body_scope.modified | orelse_scope.modified
    created = body_scope.created | orelse_scope.created
    all_modified = _stable_order(modified)
    all_referenced = body_scope.referenced | orelse_scope.referenced

    # Alias the closure variables inside the conditional functions
    # to avoid errors caused by the local variables created in the branch
    # functions.
    aliased_orig_names = _stable_order(modified - created)
    aliased_new_names = tuple(
        self.context.namer.new_symbol(s.ssf(), all_referenced)
        for s in aliased_orig_names)
    alias_map = dict(zip(aliased_orig_names, aliased_new_names))
    node_body = ast_util.rename_symbols(node.body, alias_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_map)

    # NOTE(review): when no symbol is modified, `results` is an empty
    # gast.Tuple; confirm the generated `tf.cond` statement is valid then.
    if len(all_modified) == 1:
      results = all_modified[0]
    else:
      results = gast.Tuple([s.ast() for s in all_modified], None)

    body_name = self.context.namer.new_symbol('if_true', all_referenced)
    orelse_name = self.context.namer.new_symbol('if_false', all_referenced)

    if aliased_orig_names:
      template = """
        def body_name():
          aliased_new_names, = aliased_orig_names,
          body
          return (all_results,)
        def orelse_name():
          aliased_new_names, = aliased_orig_names,
          orelse
          return (all_results,)
        results = tf.cond(test, body_name, orelse_name)
      """
      return templates.replace(
          template,
          test=node.test,
          body_name=body_name,
          body=node_body,
          orelse_name=orelse_name,
          orelse=node_orelse,
          aliased_orig_names=aliased_orig_names,
          aliased_new_names=aliased_new_names,
          # Inside the branch functions, aliased symbols go by their new names.
          all_results=tuple(alias_map.get(s, s) for s in all_modified),
          results=results)

    template = """
      def body_name():
        body
        return (all_results,)
      def orelse_name():
        orelse
        return (all_results,)
      results = tf.cond(test, body_name, orelse_name)
    """
    return templates.replace(
        template,
        test=node.test,
        body_name=body_name,
        body=node_body,
        orelse_name=orelse_name,
        orelse=node_orelse,
        all_results=all_modified,
        results=results)

  def visit_While(self, node):
    """Converts a `while` loop into a `tf.while_loop` call.

    The loop state is every symbol the body modifies but does not create.
    Each state symbol is given a fresh name, derived from its `ssf()`, that
    serves as the parameter of the generated test and body functions. The
    test and body are renamed accordingly.

    Args:
      node: the `while` statement node, annotated with NodeAnno.BODY_SCOPE.
    Returns:
      The replacement produced by templates.replace.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    body_closure = body_scope.modified - body_scope.created
    all_referenced = body_scope.referenced

    state = list(_stable_order(body_closure))
    state_ssf = [
        self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Only symbols whose fresh name differs from the original need renaming.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    # NOTE(review): an empty loop state yields an empty gast.Tuple target;
    # confirm the generated `tf.while_loop` statement is valid then.
    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    test = ast_util.rename_symbols(node.test, ssf_map)

    template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = tf.while_loop(test_name, body_name, [state])
    """
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        test_name=self.context.namer.new_symbol('loop_test', all_referenced),
        test=test,
        body_name=self.context.namer.new_symbol('loop_body', all_referenced),
        body=node_body)

    return node

  # pylint:enable=invalid-name


def transform(node, context):
  """Runs ControlFlowTransformer over `node` and returns the result."""
  t = ControlFlowTransformer(context)
  node = t.visit(node)
  return node