control_flow.py revision d9df4313a98fdc62187a94c5ab6d8955b699e9f2
15b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
25b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower#
35b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License");
45b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower# you may not use this file except in compliance with the License.
55b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower# You may obtain a copy of the License at
65b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower#
75b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower#     http://www.apache.org/licenses/LICENSE-2.0
85b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower#
95b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software
105b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS,
115b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
125b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower# See the License for the specific language governing permissions and
135b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower# limitations under the License.
145b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower# ==============================================================================
15453f94412f908aadd21561c14feae80dfac1e933A. Unique TensorFlower"""Handles control flow statements: while, if."""
165b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower
175b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlowerfrom __future__ import absolute_import
185b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlowerfrom __future__ import division
195b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlowerfrom __future__ import print_function
205b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower
215b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlowerimport gast
225b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower
235b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlowerfrom tensorflow.contrib.py2tf.pyct import anno
24d9df4313a98fdc62187a94c5ab6d8955b699e9f2A. Unique TensorFlowerfrom tensorflow.contrib.py2tf.pyct import ast_util
255b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlowerfrom tensorflow.contrib.py2tf.pyct import templates
266a822c373818948037baacfbae1c7355e0fc2c48A. Unique TensorFlowerfrom tensorflow.contrib.py2tf.pyct import transformer
276a822c373818948037baacfbae1c7355e0fc2c48A. Unique TensorFlowerfrom tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
285b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower
295b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower
class SymbolNamer(object):
  """Describes the interface for ControlFlowTransformer's namer."""

  def new_symbol(self, name_root, reserved_locals):
    """Generate a new unique symbol.

    Args:
      name_root: String, used as stem in the new name.
      reserved_locals: Set(string), additional local symbols that are reserved
          and which should not be used.
    Returns:
      String.
    Raises:
      NotImplementedError: Always. This class only describes the interface and
          provides no implementation of its own.
    """
    raise NotImplementedError()
445b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower
455b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower
class ControlFlowTransformer(transformer.Base):
  """Transforms control flow structures like loops and conditionals.

  `if` statements are rewritten into two branch functions plus a `tf.cond`
  call, and `while` statements into a test function, a body function and a
  `tf.while_loop` call. `for` statements are rejected outright.
  """

  def __init__(self, context):
    super(ControlFlowTransformer, self).__init__(context)

  # pylint:disable=invalid-name

  def visit_For(self, node):
    """Rejects `for` statements.

    Args:
      node: The `for` node being visited (unused).

    Raises:
      AssertionError: Always.
    """
    # Raised explicitly instead of `assert False, ...`: assert statements are
    # removed under `python -O`, in which case this visitor would fall through
    # and return None.
    # NOTE(review): if transformer.Base follows the ast.NodeTransformer
    # contract (not visible here), a None return would silently drop the loop
    # from the tree -- confirm.
    raise AssertionError(
        'for statement should have been canonicalized at this point')

  def visit_If(self, node):
    """Rewrites an `if` statement into branch functions and a `tf.cond` call.

    Symbols that are modified by either branch but created by neither are
    renamed inside the branch functions and initialized from the original
    symbols at the top of each function.

    Args:
      node: The `if` node. It must carry the `NodeAnno.BODY_SCOPE` and
          `NodeAnno.ORELSE_SCOPE` annotations.

    Returns:
      The result of `templates.replace` for the generated code.

    Raises:
      ValueError: If one branch creates symbols that the other does not.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

    if body_scope.created - orelse_scope.created:
      raise ValueError(
          'The if branch creates new symbols that the else branch does not.')
    if orelse_scope.created - body_scope.created:
      raise ValueError(
          'The else branch creates new symbols that the if branch does not.')

    modified = body_scope.modified | orelse_scope.modified
    created = body_scope.created | orelse_scope.created
    all_modified = tuple(modified)
    all_referenced = body_scope.referenced | orelse_scope.referenced

    # Alias the closure variables inside the conditional functions
    # to avoid errors caused by the local variables created in the branch
    # functions.
    aliased_orig_names = tuple(modified - created)
    aliased_new_names = tuple(
        self.context.namer.new_symbol(s.ssf(), all_referenced)
        for s in aliased_orig_names)
    alias_map = dict(zip(aliased_orig_names, aliased_new_names))
    node_body = ast_util.rename_symbols(node.body, alias_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_map)

    # A single modified symbol is assigned directly; otherwise the tf.cond
    # result is unpacked into a tuple of all modified symbols.
    if len(all_modified) == 1:
      results = all_modified[0]
    else:
      results = gast.Tuple([s.ast() for s in all_modified], None)

    # Function names are requested after the aliases, 'if_true' before
    # 'if_false', so the namer sees its calls in a fixed order.
    body_name = self.context.namer.new_symbol('if_true', all_referenced)
    orelse_name = self.context.namer.new_symbol('if_false', all_referenced)

    if aliased_orig_names:
      template = """
        def body_name():
          aliased_new_names, = aliased_orig_names,
          body
          return (all_results,)
        def orelse_name():
          aliased_new_names, = aliased_orig_names,
          orelse
          return (all_results,)
        results = tf.cond(test, body_name, orelse_name)
      """
      return templates.replace(
          template,
          test=node.test,
          body_name=body_name,
          body=node_body,
          orelse_name=orelse_name,
          orelse=node_orelse,
          aliased_orig_names=aliased_orig_names,
          aliased_new_names=aliased_new_names,
          # Return the alias where one exists, the original symbol otherwise.
          all_results=tuple(alias_map.get(s, s) for s in all_modified),
          results=results)
    else:
      template = """
        def body_name():
          body
          return (all_results,)
        def orelse_name():
          orelse
          return (all_results,)
        results = tf.cond(test, body_name, orelse_name)
      """
      return templates.replace(
          template,
          test=node.test,
          body_name=body_name,
          body=node_body,
          orelse_name=orelse_name,
          orelse=node_orelse,
          all_results=all_modified,
          results=results)

  def visit_While(self, node):
    """Rewrites a `while` statement into a `tf.while_loop` call.

    The loop state is the set of symbols that the body modifies but does not
    create. Each state symbol gets a new name from the namer, which is used
    as the parameter of the generated test and body functions.

    Args:
      node: The `while` node. It must carry the `NodeAnno.BODY_SCOPE`
          annotation.

    Returns:
      The result of `templates.replace` for the generated code.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    body_closure = body_scope.modified - body_scope.created
    all_referenced = body_scope.referenced

    state = list(body_closure)
    state_ssf = [
        self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Only symbols whose new name differs from the original need renaming.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    # A single state symbol is passed bare rather than wrapped in a tuple.
    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    test = ast_util.rename_symbols(node.test, ssf_map)

    # 'loop_test' is requested before 'loop_body' so the namer sees its calls
    # in a fixed order.
    test_name = self.context.namer.new_symbol('loop_test', all_referenced)
    body_name = self.context.namer.new_symbol('loop_body', all_referenced)

    template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = tf.while_loop(test_name, body_name, [state])
    """
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        test_name=test_name,
        test=test,
        body_name=body_name,
        body=node_body)

    return node

  # pylint:enable=invalid-name
1885b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower
1895b2aae39b75f5a864e0ec0dd95c7f3a07e9d16e7A. Unique TensorFlower
def transform(node, context):
  """Runs `ControlFlowTransformer` over `node`.

  Args:
    node: The AST node to transform; handed to the transformer's `visit`.
    context: Passed through to the `ControlFlowTransformer` constructor.

  Returns:
    Whatever `ControlFlowTransformer.visit` returns for `node`.
  """
  return ControlFlowTransformer(context).visit(node)
194